This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 567eeed38b [Runtime][Dist] Implementation of KV cache transfer (#17557)
567eeed38b is described below

commit 567eeed38bdbcefb68e36328af6ab1501a81d51e
Author: Hongyi Jin <[email protected]>
AuthorDate: Sun Dec 15 09:56:01 2024 -0500

    [Runtime][Dist] Implementation of KV cache transfer (#17557)
    
    This PR introduces kv transfer kernel and KV cache integration used
    in prefill-decode disaggregation.
    
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Charlie Ruan <[email protected]>
    Co-authored-by: Yingcheng Wang <[email protected]>
---
 3rdparty/flashinfer                                |   2 +-
 CMakeLists.txt                                     |   4 +-
 docs/how_to/tutorials/optimize_llm.py              |   1 +
 python/tvm/relax/frontend/nn/llm/kv_cache.py       |  33 +-
 src/runtime/contrib/nvshmem/init.cc                |  58 ++-
 src/runtime/contrib/nvshmem/kv_transfer.cu         | 333 ++++++++++++
 src/runtime/contrib/nvshmem/memory_allocator.cc    |   3 +-
 src/runtime/disco/nccl/nccl.cc                     |  10 +-
 src/runtime/relax_vm/kv_state.cc                   |   4 +
 src/runtime/relax_vm/kv_state.h                    |   8 +
 src/runtime/relax_vm/paged_kv_cache.cc             | 385 +++++++++++++-
 tests/python/disco/test_nvshmem.py                 |   4 +-
 .../test_runtime_builtin_kv_cache_transfer.py}     | 565 +++++----------------
 ...est_runtime_builtin_kv_cache_transfer_kernel.py | 252 +++++++++
 ...runtime_builtin_paged_attention_kv_cache_tir.py |   1 +
 15 files changed, 1198 insertions(+), 465 deletions(-)

diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 1e379898a5..a76ceedb94 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 1e379898a589cdd4ff18a4621fcbe18d63501545
+Subproject commit a76ceedb9495d3d05648c29a8e6bb45baa265f6c
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8abdfad24c..757b0d1a89 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -478,7 +478,12 @@ if (USE_CUDA AND USE_NVSHMEM)
   if (NOT NVSHMEM_FOUND)
     message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
   endif()
-  tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
+  # kv_transfer.cu calls NVSHMEM device APIs, which must be built as relocatable device code
+  # (CUDA separable compilation). NOTE: both variables are directory-scoped and apply to every
+  # target created after this point, not only to the NVSHMEM sources.
+  set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+  tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/contrib/nvshmem/*.cu)
   list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
 endif()
 
diff --git a/docs/how_to/tutorials/optimize_llm.py 
b/docs/how_to/tutorials/optimize_llm.py
index 9311c0557f..c30b2c381c 100644
--- a/docs/how_to/tutorials/optimize_llm.py
+++ b/docs/how_to/tutorials/optimize_llm.py
@@ -303,6 +303,7 @@ class LlamaForCasualLM(nn.Module):
             rotary_dim=self.head_dim,
             dtype=self.dtype,
             target=target,
+            enable_disaggregation=False,
         )
 
     def get_default_spec(self):
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py 
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 18f3e19909..f60c40efa2 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -169,6 +169,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
         rope_scaling: Dict[str, Any],
         rope_ext_factors: rx.Expr,
         rotary_dim: int,
+        enable_disaggregation: bool,
         dtype: str,
         target: Target,
         name: str = "paged_kv_cache",
@@ -214,6 +215,8 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
             The RoPE extension factors when "longrope" mode RoPE scaling is 
enabled.
         rotary_dim : int
             The number of dimensions in the embedding that RoPE is applied to.
+        enable_disaggregation : bool
+            Whether to enable disaggregation in the KV cache.
         """
         if rope_mode == RopeMode.INLINE:
             assert rotary_dim == head_dim, "FlashInfer RoPE does not support 
partial rotary dim."
@@ -259,6 +262,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-me
             bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, 
head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
             bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, 
num_attention_heads, head_dim, dtype, rope_scaling, target), 
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
             rope_ext_factors,
+            rx.PrimValue(enable_disaggregation),
             # fmt: on
             # pylint: enable=line-too-long
         ]
@@ -293,6 +297,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-methods
         rope_scaling: Dict[str, Any],
         rope_ext_factors: rx.Expr,
         rotary_dim: int,
+        enable_disaggregation: bool,
         dtype: str,
         target: Target,
         name: str = "paged_kv_cache",
@@ -338,6 +343,8 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-methods
             The RoPE extension factors when "longrope" mode RoPE scaling is 
enabled.
         rotary_dim : int
             The number of dimensions in the embedding that RoPE is applied to.
+        enable_disaggregation : bool
+            Whether to enable disaggregation in the KV cache.
         target : Target
             The target to build the model to.
         """
@@ -377,6 +384,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-methods
             bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, 
head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
             bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, 
num_attention_heads, head_dim, dtype, rope_scaling, target), 
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
             rope_ext_factors,
+            rx.PrimValue(enable_disaggregation),
             # fmt: on
             # pylint: enable=line-too-long
         ]
@@ -409,8 +417,9 @@ def _kv_cache_transpose_append(num_key_value_heads, 
head_dim, dtype):
         T.func_attr({"tir.noalias": T.bool(True)})
         ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
         num_pages = T.int64()
+        pages_elem_offset = T.int64()
         position_map_elem_offset = T.int32()
-        pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 
16, head_dim), dtype)
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 
16, head_dim), dtype, elem_offset=pages_elem_offset)
         k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, 
head_dim), dtype)
         v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, 
head_dim), dtype)
         position_map = T.match_buffer(
@@ -453,8 +462,9 @@ def _kv_cache_debug_get_kv(num_hidden_layers, 
num_key_value_heads, head_dim, dty
         seqlen = T.SizeVar("num_tokens_including_cache", "int64")
         page_size = T.SizeVar("page_size", "int64")
         num_pages = T.int64()
+        pages_elem_offset = T.int64()
         position_map_elem_offset = T.int64()
-        pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 
page_size, head_dim), dtype)
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 
page_size, head_dim), dtype,elem_offset=pages_elem_offset)
         position_map = T.match_buffer(
             var_position_map, (seqlen,), "int32", 
elem_offset=position_map_elem_offset
         )
@@ -594,6 +604,7 @@ def _attention_prefill(
         total_len = T.int32(is_size_var=True)
         nnz_pages = T.int32(is_size_var=True)
         max_num_pages = T.int32(is_size_var=True)
+        pages_elem_offset = T.int64(is_size_var=True)
         q_indptr_elem_offset = T.int32(is_size_var=True)
         page_indptr_elem_offset = T.int32(is_size_var=True)
         page_values_elem_offset = T.int32(is_size_var=True)
@@ -603,7 +614,7 @@ def _attention_prefill(
 
         q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
         q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", 
elem_offset=q_indptr_elem_offset)
-        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), 
dtype)
+        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), 
dtype, elem_offset=pages_elem_offset)
         page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), 
"int32", elem_offset=page_indptr_elem_offset)
         page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", 
elem_offset=page_values_elem_offset)
         k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, 
(batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
@@ -975,6 +986,7 @@ def _attention_decode(
         B = T.int32(is_size_var=True)
         nnz_pages = T.int32(is_size_var=True)
         max_num_pages = T.int32(is_size_var=True)
+        pages_elem_offset = T.int64(is_size_var=True)
         page_indptr_elem_offset = T.int32(is_size_var=True)
         page_values_elem_offset = T.int32(is_size_var=True)
         k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
@@ -983,7 +995,7 @@ def _attention_decode(
 
         Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)
         pages = T.match_buffer(
-            pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype
+            pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype, 
elem_offset=pages_elem_offset
         )
         page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), 
"int32", elem_offset=page_indptr_elem_offset)
         page_table_values = T.match_buffer(page_table_values_handle, 
(nnz_pages,), "int32", elem_offset=page_values_elem_offset)
@@ -1949,7 +1961,13 @@ def _copy_single_page(num_heads, page_size, head_dim, 
dtype, target: Target):
     ):
         T.func_attr({"tir.is_scheduled": 1})
         num_pages = T.int32()
-        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, 
head_dim), dtype)
+        pages_elem_offset = T.int64()
+        pages = T.match_buffer(
+            var_pages,
+            (num_pages, 2, num_heads, page_size, head_dim),
+            dtype,
+            elem_offset=pages_elem_offset,
+        )
 
         for b in T.thread_binding(
             (copy_length * num_heads * head_dim + tx - 1) // tx, 
thread="blockIdx.x"
@@ -1993,7 +2011,10 @@ def _compact_kv_copy(num_heads, head_dim, dtype, target: 
Target):
         total_copy_length = T.int32()
         copy_length_indptr_elem_offset = T.int32()
         copy_src_dst_pos_elem_offset = T.int32()
-        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, 
head_dim), dtype)
+        pages_elem_offset = T.int64()
+        pages = T.match_buffer(
+            var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, 
elem_offset=pages_elem_offset
+        )
         copy_length_indptr = T.match_buffer(
             var_copy_length_indptr,
             (batch_size + 1,),
diff --git a/src/runtime/contrib/nvshmem/init.cc 
b/src/runtime/contrib/nvshmem/init.cc
index 50fdde4c49..33a787b5b9 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -18,6 +18,7 @@
  */
 #include <nvshmem.h>
 #include <nvshmemx.h>
+#include <picojson.h>
 #include <tvm/runtime/disco/disco_worker.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
@@ -38,9 +39,14 @@ ShapeTuple InitNVSHMEMUID() {
   return ShapeTuple(uid_64);
 }
 
-void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
-  DiscoWorker* worker = DiscoWorker::ThreadLocal();
-  ICHECK(worker != nullptr);
+void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) {
+  DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
+  int worker_id;
+  if (worker == nullptr) {
+    worker_id = worker_id_start;
+  } else {
+    worker_id = worker_id_start + worker->worker_id;
+  }
   CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
       << "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << 
", but got "
       << uid_64.size() << ".";
@@ -52,17 +58,61 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
   for (int i = 0; i < UNIQUEID_PADDING; ++i) {
     uid.internal[i] = static_cast<char>(uid_64[i + 1]);
   }
-  nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
+  // FIXME: this is a hack to avoid the issue of NVSHMEM using 
Multi-process-per-GPU to initialize
+  cudaSetDevice(worker_id);
+  nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
   nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
   int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
   CUDA_CALL(cudaSetDevice(mype_node));
+  if (worker != nullptr) {
+    if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+      worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
+    } else {
+      ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA &&
+             worker->default_device.device_id == mype_node)
+          << "The default device of the worker is inconsistent with the device 
used for NVSHMEM. "
+          << "The default device is " << worker->default_device
+          << ", but the device used for NVSHMEM is " << 
Device{DLDeviceType::kDLCUDA, mype_node}
+          << ".";
+    }
+  }
   LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
            << ", npes=" << nvshmem_n_pes();
 }
 
+/*!
+ * \brief Parse NVSHMEM init arguments from a JSON string and forward them to InitNVSHMEM.
+ * \param args A JSON object with the fields
+ *        - "uid": array of integers, passed through as InitNVSHMEM's `uid_64`;
+ *        - "npes": integer, passed as `num_workers`;
+ *        - "pe_start": integer, passed as `worker_id_start`.
+ *        A missing field or a field of the wrong JSON type raises an error naming the field.
+ */
+void InitNVSHMEMWrapper(String args) {
+  picojson::value v;
+  std::string err = picojson::parse(v, args);
+  if (!err.empty()) {
+    LOG(FATAL) << "JSON parse error: " << err;
+  }
+
+  if (!v.is<picojson::object>()) {
+    LOG(FATAL) << "JSON is not an object";
+  }
+
+  const picojson::object& obj = v.get<picojson::object>();
+
+  // Use find() rather than operator[]: operator[] silently inserts a null value for a missing
+  // key, and the following get<T>() then fails inside picojson's type assertion instead of
+  // reporting which field is wrong.
+  auto get_field = [&obj](const std::string& key) -> const picojson::value& {
+    auto it = obj.find(key);
+    CHECK(it != obj.end()) << "JSON is missing required field \"" << key << "\"";
+    return it->second;
+  };
+  auto get_int = [&get_field](const std::string& key) -> int64_t {
+    const picojson::value& field = get_field(key);
+    CHECK(field.is<int64_t>()) << "JSON field \"" << key << "\" must be an integer";
+    return field.get<int64_t>();
+  };
+
+  const picojson::value& uid_field = get_field("uid");
+  CHECK(uid_field.is<picojson::array>()) << "JSON field \"uid\" must be an array";
+  std::vector<int64_t> uid_vector;
+  for (const picojson::value& elem : uid_field.get<picojson::array>()) {
+    CHECK(elem.is<int64_t>()) << "JSON field \"uid\" must contain only integers";
+    uid_vector.push_back(elem.get<int64_t>());
+  }
+
+  ShapeTuple uid_64(uid_vector);
+
+  int num_workers = static_cast<int>(get_int("npes"));
+  int worker_id_start = static_cast<int>(get_int("pe_start"));
+
+  InitNVSHMEM(uid_64, num_workers, worker_id_start);
+}
+
 
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);
 
 
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);
 
+// Variant of "runtime.disco.nvshmem.init_nvshmem" that receives its arguments as one JSON
+// string with the fields "uid", "npes" and "pe_start" (parsed by InitNVSHMEMWrapper).
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper")
+    .set_body_typed(InitNVSHMEMWrapper);
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu 
b/src/runtime/contrib/nvshmem/kv_transfer.cu
new file mode 100644
index 0000000000..cf3a9958f8
--- /dev/null
+++ b/src/runtime/contrib/nvshmem/kv_transfer.cu
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cuda_fp16.h>
+#include <dlpack/dlpack.h>
+#include <nvshmem.h>
+#include <tvm/runtime/disco/disco_worker.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/registry.h>
+
+// Row-major linearization of a `dim`-dimensional index:
+//   ((index[0] * shape[1] + index[1]) * shape[2] + index[2]) * ... + index[dim - 1]
+// shape[0] never affects the result (it only multiplies the initial 0), so callers may put
+// any value there. The accumulator is int64_t, so the products are computed in 64-bit even
+// though `shape` and `index` are int.
+template <int dim>
+__device__ int64_t calc_flattened_index(int shape[dim], int index[dim]) {
+  int64_t flattened_index = 0;
+#pragma unroll
+  for (int i = 0; i < dim; i++) {
+    flattened_index *= shape[i];
+    flattened_index += index[i];
+  }
+  return flattened_index;
+}
+
+// Device kernel: write the local K and V rows of each token into the paged KV cache held by a
+// remote PE, using NVSHMEM warp-level non-blocking puts.
+//  - Block b handles tokens global_pos = b, b + gridDim.x, b + 2 * gridDim.x, ...
+//  - threadIdx.y selects the local KV head `h`. The host launches blockDim.x == 32, so the
+//    32 x-threads of one head form one warp, which issues one head_dim-element put for K and
+//    one for V.
+//  - Tokens whose remote_position_map entry is -1 are skipped.
+//  - `remote_num_pages` only fills shape[0] for calc_flattened_index, where it does not
+//    change the computed offset.
+template <typename T, int local_num_kv_head, int remote_num_kv_head, int 
head_dim, int page_size>
+__global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* 
remote_position_map,
+                           int ntokens, int local_tp_rank, int32_t* 
remote_tp_group_pe_offset,
+                           int remote_num_pages) {
+  // launch grid: [num_blocks, 1, 1], [32, local_num_kv_head, 1]
+  // pages(remote): [remote_num_pages, 2, remote_num_kv_head, page_size, 
head_dim]
+  // k_data: [ntokens, local_num_kv_head, head_dim]
+  // v_data: [ntokens, local_num_kv_head, head_dim]
+  int remote_pe;
+  int remote_kv_head_index;
+  int h = threadIdx.y;  // local kv head index
+
+  for (int global_pos = blockIdx.x; global_pos < ntokens; global_pos += 
gridDim.x) {
+    int position = remote_position_map[global_pos];
+    if (position == -1) {
+      continue;
+    }
+    // Map (local TP rank, local head h) to (destination PE, head index on that PE).
+    // gather: several local ranks feed the same remote rank (remote has more heads per rank);
+    // scatter: one local rank feeds several remote ranks (remote has fewer heads per rank).
+    if (local_num_kv_head <= remote_num_kv_head) {
+      // gather
+      assert(remote_num_kv_head % local_num_kv_head == 0);
+      int gather_factor = remote_num_kv_head / local_num_kv_head;
+      remote_pe = remote_tp_group_pe_offset[global_pos] + local_tp_rank / 
gather_factor;
+      remote_kv_head_index = (local_tp_rank % gather_factor) * 
local_num_kv_head + h;
+    } else {
+      // scatter
+      assert(local_num_kv_head % remote_num_kv_head == 0);
+      int scatter_factor = local_num_kv_head / remote_num_kv_head;
+      remote_pe = remote_tp_group_pe_offset[global_pos] + local_tp_rank * 
scatter_factor +
+                  h / remote_num_kv_head;
+      remote_kv_head_index = h % remote_num_kv_head;
+    }
+    // Split the remote slot into (page, offset inside the page); index 0 / 1 of the "2"
+    // dimension holds K / V respectively.
+    int page_id = position / page_size;
+    int offset_in_page = position % page_size;
+    int pages_shape[5] = {remote_num_pages, 2, remote_num_kv_head, page_size, 
head_dim};
+    int k_page_index[5] = {page_id, 0, remote_kv_head_index, offset_in_page, 
0};
+    int v_page_index[5] = {page_id, 1, remote_kv_head_index, offset_in_page, 
0};
+    int k_v_shape[3] = {ntokens, local_num_kv_head, head_dim};
+    int k_v_index[3] = {global_pos, h, 0};
+    nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<5>(pages_shape, 
k_page_index),
+                             k_data + calc_flattened_index<3>(k_v_shape, 
k_v_index),
+                             head_dim * sizeof(T), remote_pe);
+    nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<5>(pages_shape, 
v_page_index),
+                             v_data + calc_flattened_index<3>(k_v_shape, 
k_v_index),
+                             head_dim * sizeof(T), remote_pe);
+  }
+  // NOTE(review): one thread per warp calls nvshmem_quiet(), presumably so that the
+  // non-blocking puts above have completed before the kernel exits; confirm the thread scope
+  // of nvshmem_quiet() in the NVSHMEM device API docs.
+  if (threadIdx.x == 0) {
+    nvshmem_quiet();
+  }
+}
+/*!
+ * \brief Device kernel: copy K/V rows of each token from the local paged KV cache into the
+ * paged KV cache held by a remote PE, using NVSHMEM warp-level non-blocking puts.
+ *
+ * Launch layout: grid [num_blocks, 1, 1], block [32, local_num_kv_head, 2].
+ *  - threadIdx.y is the local KV head `h`.
+ *  - threadIdx.z is the slot of the page's "2" dimension (0 = K, 1 = V).
+ *  - With blockDim.x == 32, each (y, z) pair is one warp that issues one head_dim-element put.
+ *  - Block b handles tokens global_pos = b, b + gridDim.x, ...
+ *  - Tokens whose remote or local position is -1 are skipped.
+ * Only shape[1..4] matter for flattening, so shape[0] of both page layouts is written as 1.
+ */
+template <typename T, int local_num_kv_head, int remote_num_kv_head, int head_dim, int page_size>
+__global__ void KVTransferPageToPage(T* remote_pages, T* local_pages, int32_t* remote_position_map,
+                                     int32_t* local_position_map, int ntokens, int local_tp_rank,
+                                     int32_t* remote_tp_group_pe_offset) {
+  int remote_pe;
+  int remote_kv_head_index;
+  int h = threadIdx.y;        // local kv head index
+  int kv_slot = threadIdx.z;  // index into the "2" dimension of the page layout: 0 = K, 1 = V
+
+  for (int global_pos = blockIdx.x; global_pos < ntokens; global_pos += gridDim.x) {
+    int remote_position = remote_position_map[global_pos];
+    int local_position = local_position_map[global_pos];
+    if (remote_position == -1 || local_position == -1) {
+      continue;
+    }
+    // Map (local TP rank, local head h) to (destination PE, head index on that PE).
+    if (local_num_kv_head <= remote_num_kv_head) {
+      // gather: several local ranks feed the same remote rank
+      assert(remote_num_kv_head % local_num_kv_head == 0);
+      int gather_factor = remote_num_kv_head / local_num_kv_head;
+      remote_pe = remote_tp_group_pe_offset[global_pos] + local_tp_rank / gather_factor;
+      remote_kv_head_index = (local_tp_rank % gather_factor) * local_num_kv_head + h;
+    } else {
+      // scatter: one local rank feeds several remote ranks
+      assert(local_num_kv_head % remote_num_kv_head == 0);
+      int scatter_factor = local_num_kv_head / remote_num_kv_head;
+      remote_pe = remote_tp_group_pe_offset[global_pos] + local_tp_rank * scatter_factor +
+                  h / remote_num_kv_head;
+      remote_kv_head_index = h % remote_num_kv_head;
+    }
+
+    int remote_page_id = remote_position / page_size;
+    int remote_offset_in_page = remote_position % page_size;
+    int local_page_id = local_position / page_size;
+    int local_offset_in_page = local_position % page_size;
+    int remote_pages_shape[5] = {1, 2, remote_num_kv_head, page_size, head_dim};
+    int local_pages_shape[5] = {1, 2, local_num_kv_head, page_size, head_dim};
+    int remote_page_index[5] = {remote_page_id, kv_slot, remote_kv_head_index,
+                                remote_offset_in_page, 0};
+    int local_page_index[5] = {local_page_id, kv_slot, h, local_offset_in_page, 0};
+    nvshmemx_putmem_nbi_warp(
+        remote_pages + calc_flattened_index<5>(remote_pages_shape, remote_page_index),
+        local_pages + calc_flattened_index<5>(local_pages_shape, local_page_index),
+        head_dim * sizeof(T), remote_pe);
+  }
+  // NOTE(review): one thread per warp calls nvshmem_quiet(), presumably so that the
+  // non-blocking puts above have completed before the kernel exits; confirm the thread scope
+  // of nvshmem_quiet() in the NVSHMEM device API docs.
+  if (threadIdx.x == 0) {
+    nvshmem_quiet();
+  }
+}
+
+// Each DISPATCH_* macro below maps a runtime value onto a compile-time constant (or type)
+// named by its second argument, then expands the trailing code block with that name in scope.
+// Values without a branch are rejected with LOG(FATAL).
+
+// Element dtype: only float16 (CUDA `half`) is supported. The dtype code is printed as an
+// integer because DLDataType::code is a uint8_t, which would otherwise stream as a raw char.
+#define DISPATCH_TVM_CUDA_DTYPE(dl_dtype, cuda_dtype, ...)                            \
+  if ((dl_dtype).code == kDLFloat && (dl_dtype).bits == 16) {                        \
+    using cuda_dtype = half;                                                          \
+    __VA_ARGS__                                                                       \
+  } else {                                                                            \
+    LOG(FATAL) << "Unsupported data type (code=" << static_cast<int>((dl_dtype).code) \
+               << ", bits=" << static_cast<int>((dl_dtype).bits) << ")";              \
+  }
+
+// Head dimension: only 128 is instantiated.
+#define DISPATCH_HEAD_DIM(head_dim, const_head_dim, ...) \
+  if ((head_dim) == 128) {                               \
+    constexpr int const_head_dim = 128;                  \
+    __VA_ARGS__                                          \
+  } else {                                               \
+    LOG(FATAL) << "Unsupported head dim " << (head_dim); \
+  }
+
+// Page size: 16 and 4 are instantiated.
+#define DISPATCH_PAGE_SIZE(page_size, const_page_size, ...)  \
+  if ((page_size) == 16) {                                   \
+    constexpr int const_page_size = 16;                      \
+    __VA_ARGS__                                              \
+  } else if ((page_size) == 4) {                             \
+    constexpr int const_page_size = 4;                       \
+    __VA_ARGS__                                              \
+  } else {                                                   \
+    LOG(FATAL) << "Unsupported page size " << (page_size);   \
+  }
+
+// Number of KV heads (local or remote): 1, 2, 4 and 8 are instantiated.
+#define DISPATCH_NUM_KV_HEAD(num_kv_head, const_num_kv_head, ...) \
+  if ((num_kv_head) == 1) {                                       \
+    constexpr int const_num_kv_head = 1;                          \
+    __VA_ARGS__                                                   \
+  } else if ((num_kv_head) == 2) {                                \
+    constexpr int const_num_kv_head = 2;                          \
+    __VA_ARGS__                                                   \
+  } else if ((num_kv_head) == 4) {                                \
+    constexpr int const_num_kv_head = 4;                          \
+    __VA_ARGS__                                                   \
+  } else if ((num_kv_head) == 8) {                                \
+    constexpr int const_num_kv_head = 8;                          \
+    __VA_ARGS__                                                   \
+  } else {                                                        \
+    LOG(FATAL) << "Unsupported num_kv_head " << (num_kv_head);    \
+  }
+
+/*!
+ * \brief Push local K/V data for `kv_len` tokens into remote paged KV caches with NVSHMEM.
+ *
+ * \param remote_pages Destination buffer [remote_num_pages, 2, remote_num_kv_head, page_size,
+ *        head_dim]. It is the target of nvshmemx_putmem_nbi_warp, so it must be NVSHMEM
+ *        symmetric memory.
+ * \param k Source keys whose last three dims are [kv_len, local_num_kv_heads, head_dim].
+ * \param v Source values with the same trailing shape as `k`.
+ * \param remote_position_map int32, 1-D, at least kv_len entries: destination slot of each
+ *        token in the remote cache, or -1 to skip the token.
+ * \param remote_tp_group_pe_offset int32, 1-D, at least kv_len entries: base PE to which the
+ *        kernel adds a rank-derived offset to obtain each token's destination PE.
+ * \param transfer_stream CUDA stream the transfer kernel is launched on.
+ * \return 0. Invalid inputs are reported through CHECK / LOG(FATAL).
+ * \note Strides are not consulted: all tensors are treated as compact row-major.
+ */
+int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor* remote_position_map,
+                DLTensor* remote_tp_group_pe_offset, TVMStreamHandle transfer_stream) {
+  CHECK_EQ(remote_pages->device.device_type, kDLCUDA)
+      << "The device of remote_pages matrix must be CUDA.";
+  CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA.";
+  CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be CUDA.";
+  CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+      << "The device of remote_position_map matrix must be CUDA.";
+  // The kernel dereferences this array on the GPU, so it must be CUDA memory too.
+  CHECK_EQ(remote_tp_group_pe_offset->device.device_type, kDLCUDA)
+      << "The device of remote_tp_group_pe_offset matrix must be CUDA.";
+  // DLDevice::device_id is a signed int; keep the comparisons signed.
+  int dev_id = remote_pages->device.device_id;
+  CHECK_EQ(k->device.device_id, dev_id)
+      << "The device id of remote_pages and k matrix doesn't match.";
+  CHECK_EQ(v->device.device_id, dev_id)
+      << "The device id of remote_pages and v matrix doesn't match.";
+  CHECK_EQ(remote_position_map->device.device_id, dev_id)
+      << "The device id of remote_pages and remote_position_map matrix doesn't match.";
+  CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id)
+      << "The device id of remote_pages and remote_tp_group_pe_offset matrix doesn't match.";
+
+  CHECK_EQ(remote_pages->ndim, 5);
+  int remote_num_pages = remote_pages->shape[0];
+  int remote_num_kv_head = remote_pages->shape[2];
+  int page_size = remote_pages->shape[3];
+  int head_dim = remote_pages->shape[4];
+
+  CHECK_GE(k->ndim, 3);
+  int kv_len = k->shape[k->ndim - 3];
+  int local_num_kv_heads = k->shape[k->ndim - 2];
+  CHECK_EQ(head_dim, k->shape[k->ndim - 1]);
+
+  CHECK_GE(v->ndim, 3);
+  CHECK_EQ(kv_len, v->shape[v->ndim - 3]);
+  CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2]);
+  CHECK_EQ(head_dim, v->shape[v->ndim - 1]);
+
+  CHECK(remote_pages->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1);
+  CHECK(remote_pages->dtype.bits == k->dtype.bits && remote_pages->dtype.code == k->dtype.code);
+  CHECK(remote_pages->dtype.bits == v->dtype.bits && remote_pages->dtype.code == v->dtype.code);
+
+  // The kernel reinterprets both index arrays as int32_t* and reads entry `i` for every token
+  // i in [0, kv_len), so reject inputs that would be misread or read out of bounds.
+  CHECK(remote_position_map->dtype.code == kDLInt && remote_position_map->dtype.bits == 32 &&
+        remote_position_map->dtype.lanes == 1)
+      << "remote_position_map must be an int32 array.";
+  CHECK_EQ(remote_position_map->ndim, 1);
+  CHECK_GE(remote_position_map->shape[0], kv_len)
+      << "remote_position_map must have at least one entry per token.";
+  CHECK(remote_tp_group_pe_offset->dtype.code == kDLInt &&
+        remote_tp_group_pe_offset->dtype.bits == 32 && remote_tp_group_pe_offset->dtype.lanes == 1)
+      << "remote_tp_group_pe_offset must be an int32 array.";
+  CHECK_EQ(remote_tp_group_pe_offset->ndim, 1);
+  CHECK_GE(remote_tp_group_pe_offset->shape[0], kv_len)
+      << "remote_tp_group_pe_offset must have at least one entry per token.";
+
+  // Outside a disco worker thread (no thread-local worker) this process acts as rank 0.
+  int local_tp_rank;
+  tvm::runtime::DiscoWorker* worker = tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
+  if (worker == nullptr) {
+    local_tp_rank = 0;
+  } else {
+    local_tp_rank = worker->worker_id;
+  }
+
+  // One warp (32 x-threads) per local KV head; blocks stride over the tokens.
+  dim3 blocks(8, 1, 1);
+  dim3 threads(32, local_num_kv_heads, 1);
+  DISPATCH_TVM_CUDA_DTYPE(
+      remote_pages->dtype, dtype_in,
+      {DISPATCH_HEAD_DIM(
+          head_dim, HEAD_DIM,
+          {DISPATCH_PAGE_SIZE(
+              page_size, PAGE_SIZE,
+              {DISPATCH_NUM_KV_HEAD(
+                  remote_num_kv_head, REMOTE_NUM_KV_HEAD,
+                  {DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
+                    dtype_in* remote_pages_data = reinterpret_cast<dtype_in*>(
+                        reinterpret_cast<char*>(remote_pages->data) + remote_pages->byte_offset);
+                    dtype_in* k_data = reinterpret_cast<dtype_in*>(
+                        reinterpret_cast<char*>(k->data) + k->byte_offset);
+                    dtype_in* v_data = reinterpret_cast<dtype_in*>(
+                        reinterpret_cast<char*>(v->data) + v->byte_offset);
+                    int32_t* remote_position_map_data = reinterpret_cast<int32_t*>(
+                        reinterpret_cast<char*>(remote_position_map->data) +
+                        remote_position_map->byte_offset);
+                    int32_t* remote_tp_group_pe_offset_data = reinterpret_cast<int32_t*>(
+                        reinterpret_cast<char*>(remote_tp_group_pe_offset->data) +
+                        remote_tp_group_pe_offset->byte_offset);
+                    KVTransfer<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM, PAGE_SIZE>
+                        <<<blocks, threads, 0, static_cast<cudaStream_t>(transfer_stream)>>>(
+                            remote_pages_data, k_data, v_data, remote_position_map_data, kv_len,
+                            local_tp_rank, remote_tp_group_pe_offset_data, remote_num_pages);
+                  })})})})})
+
+  // Report launch-configuration errors here rather than at some later, unrelated CUDA call.
+  cudaError_t launch_status = cudaGetLastError();
+  CHECK(launch_status == cudaSuccess)
+      << "KVTransfer kernel launch failed: " << cudaGetErrorString(launch_status);
+
+  return 0;
+}
+
+/*!
+ * \brief Copy K/V entries page-to-page from the local paged KV cache into remote caches.
+ *
+ * \param remote_pages Destination buffer [remote_num_pages, 2, remote_num_kv_head, page_size,
+ *        head_dim]. It is the target of nvshmemx_putmem_nbi_warp, so it must be NVSHMEM
+ *        symmetric memory.
+ * \param local_pages Source buffer [local_num_pages, 2, local_num_kv_head, page_size,
+ *        head_dim]; page_size and head_dim must match `remote_pages`.
+ * \param remote_position_map int32, 1-D: destination slot per token (-1 skips the token).
+ *        Its length defines the number of tokens.
+ * \param local_position_map int32, 1-D, at least as long as remote_position_map: source slot
+ *        per token (-1 skips the token).
+ * \param remote_tp_group_pe_offset int32, 1-D, at least as long as remote_position_map: base
+ *        PE to which the kernel adds a rank-derived offset to obtain the destination PE.
+ * \param transfer_stream CUDA stream the transfer kernel is launched on.
+ * \return 0. Invalid inputs are reported through CHECK / LOG(FATAL).
+ * \note Strides are not consulted: all tensors are treated as compact row-major.
+ */
+int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages,
+                          DLTensor* remote_position_map, DLTensor* local_position_map,
+                          DLTensor* remote_tp_group_pe_offset, TVMStreamHandle transfer_stream) {
+  CHECK_EQ(remote_pages->device.device_type, kDLCUDA)
+      << "The device of remote_pages matrix must be CUDA.";
+  CHECK_EQ(local_pages->device.device_type, kDLCUDA)
+      << "The device of local_pages matrix must be CUDA.";
+  CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+      << "The device of remote_position_map matrix must be CUDA.";
+  CHECK_EQ(local_position_map->device.device_type, kDLCUDA)
+      << "The device of local_position_map matrix must be CUDA.";
+  CHECK_EQ(remote_tp_group_pe_offset->device.device_type, kDLCUDA)
+      << "The device of remote_tp_group_pe_offset matrix must be CUDA.";
+  // DLDevice::device_id is a signed int; keep the comparisons signed.
+  int dev_id = remote_pages->device.device_id;
+  CHECK_EQ(local_pages->device.device_id, dev_id)
+      << "The device id of remote_pages and local_pages matrix doesn't match.";
+  CHECK_EQ(remote_position_map->device.device_id, dev_id)
+      << "The device id of remote_pages and remote_position_map matrix doesn't match.";
+  CHECK_EQ(local_position_map->device.device_id, dev_id)
+      << "The device id of remote_pages and local_position_map matrix doesn't match.";
+  CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id)
+      << "The device id of remote_pages and remote_tp_group_pe_offset matrix doesn't match.";
+
+  CHECK_EQ(remote_pages->ndim, 5);
+  int remote_num_kv_head = remote_pages->shape[2];
+  int page_size = remote_pages->shape[3];
+  int head_dim = remote_pages->shape[4];
+
+  // The kernel addresses both buffers with a single (page_size, head_dim) pair taken from
+  // remote_pages, so the local layout must be exactly 5-D and agree on those dimensions.
+  CHECK_EQ(local_pages->ndim, 5);
+  int local_num_kv_heads = local_pages->shape[2];
+  CHECK_EQ(page_size, local_pages->shape[3]);
+  CHECK_EQ(head_dim, local_pages->shape[4]);
+
+  CHECK_EQ(remote_position_map->ndim, 1);
+  int ntokens = remote_position_map->shape[0];
+
+  CHECK(remote_pages->dtype.lanes == 1 && local_pages->dtype.lanes == 1);
+  CHECK(remote_pages->dtype.bits == local_pages->dtype.bits &&
+        remote_pages->dtype.code == local_pages->dtype.code);
+
+  // The kernel reinterprets the index arrays as int32_t* and reads entry `i` for every token
+  // i in [0, ntokens), so reject inputs that would be misread or read out of bounds.
+  CHECK(remote_position_map->dtype.code == kDLInt && remote_position_map->dtype.bits == 32 &&
+        remote_position_map->dtype.lanes == 1)
+      << "remote_position_map must be an int32 array.";
+  CHECK(local_position_map->dtype.code == kDLInt && local_position_map->dtype.bits == 32 &&
+        local_position_map->dtype.lanes == 1)
+      << "local_position_map must be an int32 array.";
+  CHECK_EQ(local_position_map->ndim, 1);
+  CHECK_GE(local_position_map->shape[0], ntokens)
+      << "local_position_map must have at least one entry per token.";
+  CHECK(remote_tp_group_pe_offset->dtype.code == kDLInt &&
+        remote_tp_group_pe_offset->dtype.bits == 32 && remote_tp_group_pe_offset->dtype.lanes == 1)
+      << "remote_tp_group_pe_offset must be an int32 array.";
+  CHECK_EQ(remote_tp_group_pe_offset->ndim, 1);
+  CHECK_GE(remote_tp_group_pe_offset->shape[0], ntokens)
+      << "remote_tp_group_pe_offset must have at least one entry per token.";
+
+  // Outside a disco worker thread (no thread-local worker) this process acts as rank 0.
+  int local_tp_rank;
+  tvm::runtime::DiscoWorker* worker = tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
+  if (worker == nullptr) {
+    local_tp_rank = 0;
+  } else {
+    local_tp_rank = worker->worker_id;
+  }
+
+  // One warp per (local KV head, K/V slot): threadIdx.y = head, threadIdx.z = 0 (K) / 1 (V).
+  dim3 blocks(8, 1, 1);
+  dim3 threads(32, local_num_kv_heads, 2);
+  DISPATCH_TVM_CUDA_DTYPE(
+      remote_pages->dtype, dtype_in,
+      {DISPATCH_HEAD_DIM(
+          head_dim, HEAD_DIM,
+          {DISPATCH_PAGE_SIZE(
+              page_size, PAGE_SIZE,
+              {DISPATCH_NUM_KV_HEAD(
+                  remote_num_kv_head, REMOTE_NUM_KV_HEAD,
+                  {DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
+                    dtype_in* remote_pages_data = reinterpret_cast<dtype_in*>(
+                        reinterpret_cast<char*>(remote_pages->data) + remote_pages->byte_offset);
+                    dtype_in* local_pages_data = reinterpret_cast<dtype_in*>(
+                        reinterpret_cast<char*>(local_pages->data) + local_pages->byte_offset);
+                    int32_t* remote_position_map_data = reinterpret_cast<int32_t*>(
+                        reinterpret_cast<char*>(remote_position_map->data) +
+                        remote_position_map->byte_offset);
+                    int32_t* local_position_map_data = reinterpret_cast<int32_t*>(
+                        reinterpret_cast<char*>(local_position_map->data) +
+                        local_position_map->byte_offset);
+                    int32_t* remote_tp_group_pe_offset_data = reinterpret_cast<int32_t*>(
+                        reinterpret_cast<char*>(remote_tp_group_pe_offset->data) +
+                        remote_tp_group_pe_offset->byte_offset);
+                    KVTransferPageToPage<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM,
+                                         PAGE_SIZE>
+                        <<<blocks, threads, 0, static_cast<cudaStream_t>(transfer_stream)>>>(
+                            remote_pages_data, local_pages_data, remote_position_map_data,
+                            local_position_map_data, ntokens, local_tp_rank,
+                            remote_tp_group_pe_offset_data);
+                  })})})})})
+
+  // Report launch-configuration errors here rather than at some later, unrelated CUDA call.
+  cudaError_t launch_status = cudaGetLastError();
+  CHECK(launch_status == cudaSuccess)
+      << "KVTransferPageToPage kernel launch failed: " << cudaGetErrorString(launch_status);
+
+  return 0;
+}
+
+// Expose the host-side transfer entry points as TVM global packed functions.
+TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer);
+TVM_REGISTER_GLOBAL("nvshmem.KVTransferPageToPage").set_body_typed(_KVTransferPageToPage);
diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc 
b/src/runtime/contrib/nvshmem/memory_allocator.cc
index 89d56ed3dc..4380c7e65d 100644
--- a/src/runtime/contrib/nvshmem/memory_allocator.cc
+++ b/src/runtime/contrib/nvshmem/memory_allocator.cc
@@ -25,6 +25,7 @@
 #include <thread>
 
 #include "../../cuda/cuda_common.h"
+#include "../../disco/utils.h"
 #include "../../memory/pooled_allocator.h"
 
 namespace tvm {
@@ -88,7 +89,7 @@ class NVSHMEMAllocator final : public PooledAllocator {
 };
 
+/*!
+ * \brief Allocate an NDArray from the global NVSHMEM pooled allocator.
+ * NOTE(review): `device` goes through UseDefaultDeviceIfNone, presumably replacing an
+ * unset (CPU) device with the disco worker's default device -- confirm in disco/utils.h.
+ */
 NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
-  return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
+  return NVSHMEMAllocator::Global()->Empty(shape, dtype, 
UseDefaultDeviceIfNone(device));
 }
 
 
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 6ee54e14f3..75e7db483e 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -93,7 +93,15 @@ void InitCCLPerWorker(IntTuple device_ids, std::string 
unique_id_bytes) {
   StreamCreate(&ctx->default_stream);
 #endif
   Device device{TVM_DISCO_DEVICE_TYPE, device_id};
-  worker->default_device = device;
+  if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+    worker->default_device = device;
+  } else {
+    ICHECK(worker->default_device.device_type == device.device_type &&
+           worker->default_device.device_id == device.device_id)
+        << "The default device of the worker is inconsistent with the device 
used for CCL. "
+        << "The default device is " << worker->default_device << ", but the 
device used for CCL is "
+        << device << ".";
+  }
   worker->ccl = TVM_DISCO_CCL_NAME;
   ctx->worker = worker;
   ctx->device_id = device_id;
diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index b730a4eb07..67afda3bfd 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -56,6 +56,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
     .set_body_method<KVState>(&KVStateObj::EndForward);
 
 // Attention KV Cache methods
+// Prefill/decode disaggregation: prepare to receive KV data, and mark KV data to send.
+TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv")
+    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DisaggPrepareRecv);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send")
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DisaggMarkSend);
 
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
     
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
 
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 6d30ce998a..7df3215d08 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -157,6 +157,14 @@ class AttentionKVCacheObj : public KVStateObj {
   virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
                                             const IntTuple& leaf_indices) = 0;
 
+  /*!
+   * \brief Prepare the KV cache to receive the KV data of `length` tokens of the
+   * specified sequence (prefill/decode disaggregation).
+   * \param seq_id The id of the sequence that receives KV data.
+   * \param length The number of tokens whose KV data will be received.
+   * \return The positions in the KV cache where the received KV data will be stored,
+   *   in the compressed form [n, begin_1, length_1, ..., begin_n, length_n].
+   */
+  virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0;
+
+  /*!
+   * \brief Mark which tokens' KV data needs to be sent to other devices.
+   * \param seq_id The id of the sequence whose KV data is sent.
+   * \param begin The first position in the sequence whose KV data is sent.
+   * \param compressed_remote_position_map The receiver-side positions, in the compressed
+   *   form [n, begin_1, length_1, ..., begin_n, length_n].
+   * \param recver_pe_offset The PE offset of the receiver.
+   */
+  virtual void DisaggMarkSend(int64_t seq_id, int64_t begin,
+                              const IntTuple& compressed_remote_position_map,
+                              int32_t recver_pe_offset) = 0;
+
   /************** Attention **************/
 
   /*!
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index b6636ae1a7..81c55bfcb6 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -129,6 +129,13 @@ struct Block {
   }
 };
 
+/*!
+ * \brief Per-sequence bookkeeping for sending KV data to a remote receiver.
+ * Filled by DisaggMarkSend and read when the per-forward KV transfer position maps
+ * are built.
+ */
+struct KVTransferMetadata {
+  /*! \brief First sequence position whose KV data is sent; the max() default means
+   *  no appended token is sent. */
+  int64_t start = std::numeric_limits<int64_t>::max();
+  /*! \brief Remote slot for each position from `start` onward (the decompressed
+   *  receiver-side position map). */
+  std::vector<int64_t> remote_position_map;
+  /*! \brief PE offset of the receiver; -1 until a send is marked. */
+  int32_t recver_pe_offset = -1;
+  /*! \brief Local slots of already-cached tokens queued for page-to-page transfer;
+   *  cleared once they are moved into the per-forward transfer maps. */
+  std::vector<int64_t> local_position_map;
+};
+
 /*!
  * \brief The sequence structure in paged KV cache with common prefix support.
  * Each sequence contains one or more blocks to support common prefix.
@@ -163,6 +170,8 @@ struct Sequence {
   std::vector<int32_t> token_tree_parent_ptr;
   /*! \brief The depth of each node in the token tree. */
   std::vector<int32_t> token_tree_node_depths;
+  /*! \brief The KV transfer metadata of this sequence (what to send, and where). */
+  KVTransferMetadata kv_transfer_metadata;
   /*!
    * \brief A boolean denoting whether the accepted token tree indices of
    * this sequence are committed
@@ -231,7 +240,13 @@ class HostMemoryVector {
   }
 
   void push_back(int32_t value) {
-    ICHECK_LT(current_size_, reserved_size_);
+    ICHECK_LE(current_size_, reserved_size_);
+    if (current_size_ == reserved_size_) {
+      reserved_size_ *= 2;
+      NDArray new_data = NDArray::Empty({reserved_size_}, data_->dtype, 
data_->device);
+      std::memcpy(new_data->data, data_->data, current_size_ * 
DataType(data_->dtype).bytes());
+      data_ = new_data;
+    }
     static_cast<int32_t*>(data_->data)[current_size_++] = value;
   }
 
@@ -255,6 +270,15 @@ class HostMemoryVector {
   /*! \brief Return the vector as an NDArray. */
   NDArray as_ndarray() { return data_.CreateView({current_size_}, 
data_->dtype); }
 
+  /*! \brief Return the current elements, widened to int64, as an IntTuple. */
+  IntTuple as_int_tuple() const {
+    const int32_t* first = static_cast<const int32_t*>(data_->data);
+    return IntTuple(std::vector<int64_t>(first, first + current_size_));
+  }
+
  private:
   int64_t reserved_size_ = 0;
   int64_t current_size_ = 0;
@@ -331,6 +355,16 @@ class PagedKVCacheAuxDataManager {
    * appending new K/V data.
    */
   virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0;
+  /*! \brief Copy the remote position map for KV transfer. */
+  virtual NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) 
= 0;
+  /*! \brief Copy the receiver id for KV transfer. */
+  virtual NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) = 0;
+  /*! \brief Copy the local position map for KV page-to-page transfer. */
+  virtual NDArray 
CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) = 0;
+  /*! \brief Copy the remote position map for KV page-to-page transfer. */
+  virtual NDArray 
CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) = 0;
+  /*! \brief Copy the receiver id for KV page-to-page transfer. */
+  virtual NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) 
= 0;
   /*! \brief Copy the tree attention mask. */
   virtual NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int 
depth) = 0;
   /*! \brief Copy the mn indptr of the tree attention mask. */
@@ -390,7 +424,14 @@ class PlainPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
     k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, 
dtype_aux_, device);
     q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, 
dtype_aux_, device);
     append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, 
dtype_aux_, device);
-
+    // Each KV-transfer auxiliary array gets its own device buffer. In particular the
+    // page-to-page local and remote position maps must not alias: they are copied one
+    // after the other, and a shared buffer would let the remote map overwrite the local one.
+    kv_transfer_remote_position_map_device =
+        NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+    kv_transfer_recver_id_device = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+    kv_transfer_page_to_page_local_position_map_device =
+        NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+    kv_transfer_page_to_page_remote_position_map_device =
+        NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+    kv_transfer_page_to_page_recver_id_device =
+        NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
     commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 
1}, dtype_aux_, device);
     commit_copy_src_dst_pos_in_page_table_device_ =
         NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, 
prefill_chunk_size)},
@@ -453,6 +494,37 @@ class PlainPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
     CopyVecDataToArray(view, data->data());
     return view;
   }
+  /*! \brief Copy the KV-transfer remote position map into its preallocated device buffer. */
+  NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
+    const int64_t num_elems = static_cast<int64_t>(data->size());
+    NDArray dst = kv_transfer_remote_position_map_device.CreateView({num_elems}, dtype_aux_);
+    CopyVecDataToArray(dst, data->data());
+    return dst;
+  }
+  /*! \brief Copy the KV-transfer receiver ids into their preallocated device buffer. */
+  NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final {
+    const int64_t num_elems = static_cast<int64_t>(data->size());
+    NDArray dst = kv_transfer_recver_id_device.CreateView({num_elems}, dtype_aux_);
+    CopyVecDataToArray(dst, data->data());
+    return dst;
+  }
+  /*! \brief Copy the page-to-page local position map into its preallocated device buffer. */
+  NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final {
+    const int64_t num_elems = static_cast<int64_t>(data->size());
+    NDArray dst =
+        kv_transfer_page_to_page_local_position_map_device.CreateView({num_elems}, dtype_aux_);
+    CopyVecDataToArray(dst, data->data());
+    return dst;
+  }
+  /*! \brief Copy the page-to-page remote position map into its preallocated device buffer. */
+  NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final {
+    const int64_t num_elems = static_cast<int64_t>(data->size());
+    NDArray dst =
+        kv_transfer_page_to_page_remote_position_map_device.CreateView({num_elems}, dtype_aux_);
+    CopyVecDataToArray(dst, data->data());
+    return dst;
+  }
+  /*! \brief Copy the page-to-page receiver ids into their preallocated device buffer. */
+  NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final {
+    const int64_t num_elems = static_cast<int64_t>(data->size());
+    NDArray dst = kv_transfer_page_to_page_recver_id_device.CreateView({num_elems}, dtype_aux_);
+    CopyVecDataToArray(dst, data->data());
+    return dst;
+  }
+
   NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) 
final {
     NDArray view =
         
tree_attn_mask_device_[depth].CreateView({static_cast<int64_t>(data->size())}, 
dtype_aux_);
@@ -566,6 +638,11 @@ class PlainPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
   NDArray k_ragged_rope_pos_offset_device_;
   NDArray q_rope_position_map_device_;
   NDArray append_position_map_device_;
+  NDArray kv_transfer_remote_position_map_device;
+  NDArray kv_transfer_recver_id_device;
+  NDArray kv_transfer_page_to_page_local_position_map_device;
+  NDArray kv_transfer_page_to_page_remote_position_map_device;
+  NDArray kv_transfer_page_to_page_recver_id_device;
   NDArray commit_copy_length_indptr_device_;
   NDArray commit_copy_src_dst_pos_in_page_table_device_;
 };
@@ -633,6 +710,21 @@ class CachedPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
   NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final {
     return CopyAttnAuxVecToCache(data);
   }
+  // Every KV-transfer auxiliary array is staged through CopyAttnAuxVecToCache, in the
+  // same way as CopyAppendPositionMapAsync.
+  NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
+    NDArray staged = CopyAttnAuxVecToCache(data);
+    return staged;
+  }
+  NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final {
+    NDArray staged = CopyAttnAuxVecToCache(data);
+    return staged;
+  }
+  NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final {
+    NDArray staged = CopyAttnAuxVecToCache(data);
+    return staged;
+  }
+  NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final {
+    NDArray staged = CopyAttnAuxVecToCache(data);
+    return staged;
+  }
+  NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final {
+    NDArray staged = CopyAttnAuxVecToCache(data);
+    return staged;
+  }
   NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) 
final {
     NDArray mask_1d = CopyAttnAuxVecToCache(data);
     return mask_1d.CreateView({static_cast<int64_t>(data->size() / 2), 2}, 
mask_1d->dtype);
@@ -736,12 +828,22 @@ class CachedPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
     //  - k_ragged_rope_pos_offset
     //  - q_rope_position_map
     //  - append_position_map
+    //  - kv_transfer_remote_position_map
+    //  - kv_transfer_recver_id
+    //  - kv_transfer_page_to_page_local_position_map
+    //  - kv_transfer_page_to_page_remote_position_map
+    //  - kv_transfer_page_to_page_recver_id
     //  - tree_attn_mask
     //  - tree_attn_mn_indptr
     cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
     cache_size += CeilDivElemAlignment(reserved_num_seqs);
     cache_size += CeilDivElemAlignment(prefill_chunk_size);
     cache_size += CeilDivElemAlignment(prefill_chunk_size);
+    cache_size += CeilDivElemAlignment(prefill_chunk_size);
+    cache_size += CeilDivElemAlignment(prefill_chunk_size);
+    cache_size += CeilDivElemAlignment(prefill_chunk_size);
+    cache_size += CeilDivElemAlignment(prefill_chunk_size);
+    cache_size += CeilDivElemAlignment(prefill_chunk_size);
     cache_size +=
         CeilDivElemAlignment(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * 
reserved_num_seqs);
     cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
@@ -862,11 +964,15 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
 
   /*!
    * \brief The KV data managed by the KV cache.
-   * The array has `num_layers` NDArrays, each of them
+   * If KV transfer is enabled, the KV data of all layers is allocated by NVSHMEM as a
+   * single NDArray (nvshmem_pages_), and pages_ holds a per-layer view into it.
+   * Otherwise, pages_ has `num_layers` NDArrays, each of them
    * has layout (num_pages, 2, num_heads, page_size, head_dim).
    * Along on the "2" dimension, index 0 stands for K and 1 stands for V.
    */
-  Array<NDArray> pages_;
+  std::vector<NDArray> pages_;
+  /*! \brief The KV data of all layers allocated by NVSHMEM when KV transfer is enabled. */
+  NDArray nvshmem_pages_;
   /*! \brief The list of ids of released pages for page reuse. */
   std::vector<int32_t> free_page_ids_;
   /*! \brief The mapping from sequence ids to sequences. */
@@ -907,6 +1013,12 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   std::vector<bool> use_decode_kernel_;
   /*! \brief Whether the attention request is a decode request, set in 
BeginForwardFunction. */
   bool is_decode_request_;
+  /*!
+   * \brief Whether any token appended in the current forward has KV data to send to a
+   * remote receiver. It is reset to false and then set while the per-token KV transfer
+   * position maps are built. The receiver PE offset is recorded per token in
+   * kv_transfer_recver_id_host_ (-1 for tokens that are not sent).
+   */
+  bool transfer_kv_;
+  /*!
+   * \brief Whether any sequence in the current forward has already-cached KV data
+   * queued for page-to-page transfer.
+   */
+  bool page_to_page_transfer_kv_;
   /*! \brief The auxiliary data manager for attention. */
   std::unique_ptr<PagedKVCacheAuxDataManager> aux_data_manager_;
 
@@ -940,6 +1052,11 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   HostMemoryVector commit_copy_length_indptr_host_;
   HostMemoryVector commit_copy_src_pos_in_page_table_host_;
   HostMemoryVector commit_copy_dst_pos_in_page_table_host_;
+  HostMemoryVector kv_transfer_remote_position_map_host_;
+  HostMemoryVector kv_transfer_recver_id_host_;
+  HostMemoryVector kv_transfer_page_to_page_local_position_map_host_;
+  HostMemoryVector kv_transfer_page_to_page_remote_position_map_host_;
+  HostMemoryVector kv_transfer_page_to_page_recver_id_host_;
 
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
@@ -952,6 +1069,11 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   NDArray k_ragged_rope_pos_offset_view_;
   NDArray q_rope_position_map_view_;
   NDArray append_position_map_view_;
+  NDArray kv_transfer_remote_position_map_view_;
+  NDArray kv_transfer_recver_id_view_;
+  NDArray kv_transfer_page_to_page_local_position_map_view_;
+  NDArray kv_transfer_page_to_page_remote_position_map_view_;
+  NDArray kv_transfer_page_to_page_recver_id_view_;
   NDArray temp_attn_output_view_;
   NDArray temp_attn_scores_view_;
   NDArray merged_attn_scores_view_;
@@ -964,6 +1086,8 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   std::vector<NDArray> tree_attn_mn_indptr_view_;
 
   PackedFunc f_transpose_append_;
+  Optional<PackedFunc> f_transfer_kv_;
+  Optional<PackedFunc> f_transfer_kv_page_to_page_ = NullOpt;
   PackedFunc f_compact_copy_;
   PackedFunc f_attention_prefill_;
   PackedFunc f_attention_decode_;
@@ -989,6 +1113,8 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   TVMStreamHandle compute_stream_ = nullptr;
   /*! \brief The device stream for copying auxiliary data structure to GPU. */
   TVMStreamHandle copy_stream_ = nullptr;
+  /*! \brief The device stream for KV transfer */
+  TVMStreamHandle kv_transfer_stream_ = nullptr;
 
  public:
   /*! \brief Constructor. Take the cache configuration and initialize the 
NDArrays. */
@@ -997,7 +1123,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t 
reserved_num_seqs,
       int64_t num_total_pages, int64_t prefill_chunk_size, bool 
support_sliding_window,
       RoPEMode rope_mode, double rotary_scale, double rotary_theta,
-      Optional<NDArray> rope_ext_factors, DLDataType dtype, Device device,
+      Optional<NDArray> rope_ext_factors, bool enable_kv_transfer, DLDataType 
dtype, Device device,
       PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc 
f_attention_prefill,
       PackedFunc f_attention_decode, PackedFunc 
f_attention_prefill_sliding_window,
       PackedFunc f_attention_decode_sliding_window, PackedFunc 
f_attention_prefill_ragged,
@@ -1047,10 +1173,35 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         f_debug_get_kv_(std::move(f_debug_get_kv)),
         device_(device) {
     pages_.reserve(num_layers);
-    for (int i = 0; i < num_layers; ++i) {
-      pages_.push_back(
-          NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, 
head_dim}, dtype, device));
+    if (enable_kv_transfer) {
+      CHECK(Registry::Get("runtime.disco.nvshmem.init_nvshmem") != nullptr)
+          << "NVSHMEM is not enabled. Please make sure NVSHMEM is enabled when 
compiling TVM.";
+      const PackedFunc* f_nvshmem_empty = 
runtime::Registry::Get("runtime.disco.nvshmem.empty");
+      ICHECK_NOTNULL(f_nvshmem_empty);
+      nvshmem_pages_ = (*f_nvshmem_empty)(
+          ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, 
head_dim}), dtype,
+          device);
+      for (int i = 0; i < num_layers; ++i) {
+        pages_.push_back(nvshmem_pages_.CreateView(
+            {num_total_pages_, 2, num_kv_heads_, page_size_, head_dim_}, 
nvshmem_pages_->dtype,
+            i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * head_dim_ *
+                nvshmem_pages_.DataType().bytes()));
+      }
+
+      const PackedFunc* f_transfer_kv_ptr = 
Registry::Get("nvshmem.KVTransfer");
+      const PackedFunc* f_transfer_kv_page_to_page_ptr =
+          Registry::Get("nvshmem.KVTransferPageToPage");
+      ICHECK_NOTNULL(f_transfer_kv_ptr);
+      ICHECK_NOTNULL(f_transfer_kv_page_to_page_ptr);
+      f_transfer_kv_ = *f_transfer_kv_ptr;
+      f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr;
+    } else {
+      for (int i = 0; i < num_layers; ++i) {
+        pages_.push_back(
+            NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, 
head_dim}, dtype, device));
+      }
     }
+
     // Allocate the host memory.
     Device preferred_host_device = GetPreferredHostDevice(device);
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
@@ -1079,6 +1230,16 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         HostMemoryVector(prefill_chunk_size, dtype_aux_, 
preferred_host_device);
     append_position_map_host_ =
         HostMemoryVector(prefill_chunk_size, dtype_aux_, 
preferred_host_device);
+    kv_transfer_remote_position_map_host_ =
+        HostMemoryVector(prefill_chunk_size, dtype_aux_, 
preferred_host_device);
+    kv_transfer_recver_id_host_ =
+        HostMemoryVector(prefill_chunk_size, dtype_aux_, 
preferred_host_device);
+    kv_transfer_page_to_page_local_position_map_host_ =
+        HostMemoryVector(prefill_chunk_size, dtype_aux_, 
preferred_host_device);
+    kv_transfer_page_to_page_remote_position_map_host_ =
+        HostMemoryVector(prefill_chunk_size, dtype_aux_, 
preferred_host_device);
+    kv_transfer_page_to_page_recver_id_host_ =
+        HostMemoryVector(prefill_chunk_size, dtype_aux_, 
preferred_host_device);
     cur_append_lengths_indptr_host_ =
         HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, 
preferred_host_device);
     commit_copy_length_indptr_host_ =
@@ -1135,6 +1296,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       // The compute stream is the default stream.
       compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
       copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
+      kv_transfer_stream_ = DeviceAPI::Get(device)->CreateStream(device);
     }
 
     // Create the auxiliary data manager for attention.
@@ -1162,12 +1324,14 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     if (copy_stream_ != nullptr) {
       DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);
     }
+    if (kv_transfer_stream_ != nullptr) {
+      DeviceAPI::Get(device_)->FreeStream(device_, kv_transfer_stream_);
+    }
   }
 
   /*! \brief Reset the KV cache. */
   void Clear() final {
     seq_map_.clear();
-    ICHECK(pages_.defined());
     free_page_ids_.clear();
     for (int64_t page_id = num_total_pages_ - 1; page_id >= 0; --page_id) {
       free_page_ids_.push_back(page_id);
@@ -1331,7 +1495,8 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       DeviceAPI::Get(device_)->SetStream(device_, copy_stream_);
     }
     for (int layer = 0; layer < num_layers_; ++layer) {
-      f_copy_single_page_(pages_[layer], src_page_id, tgt_page_id, 
copy_length);
+      NDArray page_layer_view = pages_[layer];
+      f_copy_single_page_(page_layer_view, src_page_id, tgt_page_id, 
copy_length);
     }
     if (copy_stream_ != compute_stream_) {
       // Set the compute stream back.
@@ -1677,6 +1842,13 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     // in the global KV cache. The mapping is used in when appending k/v 
values.
     q_rope_position_map_host_.clear();
     append_position_map_host_.clear();
+    kv_transfer_remote_position_map_host_.clear();
+    kv_transfer_recver_id_host_.clear();
+    kv_transfer_page_to_page_local_position_map_host_.clear();
+    kv_transfer_page_to_page_remote_position_map_host_.clear();
+    kv_transfer_page_to_page_recver_id_host_.clear();
+    transfer_kv_ = false;
+    page_to_page_transfer_kv_ = false;
     for (int i = 0; i < cur_batch_size_; ++i) {
       int64_t append_length = append_lengths[i];
       const Block& block = global_block_pool_[sequences[i]->last_block_idx];
@@ -1710,11 +1882,40 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
                                                   page_size_ +
                                               offset_in_block % page_size_);
         }
+        // Record, for every appended token, where its KV data goes on the receiver.
+        // Tokens before the sequence's transfer start get -1 in both maps.
+        const KVTransferMetadata& transfer_meta = sequences[i]->kv_transfer_metadata;
+        int64_t pos_in_seq = sequences[i]->seq_length - append_length + pos;
+        int64_t seq_send_start = transfer_meta.start;
+        if (pos_in_seq < seq_send_start) {
+          kv_transfer_remote_position_map_host_.push_back(-1);
+          kv_transfer_recver_id_host_.push_back(-1);
+        } else {
+          int64_t remote_idx = pos_in_seq - seq_send_start;
+          // Fail loudly instead of reading past the remote map, e.g. when more tokens are
+          // appended than the receiver prepared slots for.
+          CHECK_LT(remote_idx, static_cast<int64_t>(transfer_meta.remote_position_map.size()))
+              << "The KV transfer remote position map of the sequence has "
+              << transfer_meta.remote_position_map.size()
+              << " entries, which does not cover position " << pos_in_seq << ".";
+          transfer_kv_ = true;
+          kv_transfer_remote_position_map_host_.push_back(
+              transfer_meta.remote_position_map[remote_idx]);
+          kv_transfer_recver_id_host_.push_back(transfer_meta.recver_pe_offset);
+        }
+      }
+      if (!sequences[i]->kv_transfer_metadata.local_position_map.empty()) {
+        page_to_page_transfer_kv_ = true;
+        for (int pos = 0;
+             pos < 
static_cast<int>(sequences[i]->kv_transfer_metadata.local_position_map.size());
+             ++pos) {
+          kv_transfer_page_to_page_local_position_map_host_.push_back(
+              sequences[i]->kv_transfer_metadata.local_position_map[pos]);
+          kv_transfer_page_to_page_remote_position_map_host_.push_back(
+              sequences[i]->kv_transfer_metadata.remote_position_map[pos]);
+          kv_transfer_page_to_page_recver_id_host_.push_back(
+              sequences[i]->kv_transfer_metadata.recver_pe_offset);
+        }
+        sequences[i]->kv_transfer_metadata.local_position_map.clear();
       }
     }
   }
 
   void EndForward() final {
+    if (kv_transfer_stream_ != nullptr) {
+      DeviceAPI::Get(device_)->SyncStreamFromTo(device_, kv_transfer_stream_, 
compute_stream_);
+    }
     if (!f_attention_prefill_end_forward_.defined() || 
!f_attention_decode_end_forward_.defined() ||
         !f_attention_prefill_ragged_end_forward_.defined()) {
       return;
@@ -1726,6 +1927,88 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
   }
 
+  /*!
+   * \brief Prepare to receive the KV data of `append_length` tokens of sequence `seq_id`.
+   * Runs BeginForward for the sequence and returns where the appended KV data will be
+   * stored, in compressed form [n, begin_1, length_1, ..., begin_n, length_n], which
+   * decompresses to [begin_1, ..., begin_1+length_1-1, ..., begin_n, ..., begin_n+length_n-1].
+   */
+  IntTuple DisaggPrepareRecv(int64_t seq_id, int append_length) final {
+    // An empty append would leave append_position_map_host_ empty, while the compression
+    // below reads its first and last elements. Check before mutating the cache.
+    CHECK_GT(append_length, 0) << "DisaggPrepareRecv requires a positive append length.";
+    // No CPU to GPU copy is needed.
+    // Essentially we
+    // (step 1.) redirect the preparation to BeginForward.
+    BeginForward({seq_id}, {append_length}, /*opt_token_tree_parent_ptr=*/NullOpt);
+    // (step 2.) fetch the append_position_map, compress and return.
+    CHECK_EQ(append_position_map_host_.size(), append_length);
+    std::vector<int64_t> compressed_append_pos_map{/*num_segments=*/1,
+                                                   append_position_map_host_[0]};
+    for (int i = 1; i < append_length; ++i) {
+      if (append_position_map_host_[i] != append_position_map_host_[i - 1] + 1) {
+        // Terminate the current segment.
+        compressed_append_pos_map.push_back(append_position_map_host_[i - 1] -
+                                            compressed_append_pos_map.back() + 1);
+        // Start a new segment.
+        ++compressed_append_pos_map[0];
+        compressed_append_pos_map.push_back(append_position_map_host_[i]);
+      }
+    }
+    // Terminate the last segment.
+    compressed_append_pos_map.push_back(append_position_map_host_.back() -
+                                        compressed_append_pos_map.back() + 1);
+    // The compressed array size should be "num_segments * 2 + 1".
+    CHECK_EQ(compressed_append_pos_map.size(), compressed_append_pos_map[0] * 2 + 1);
+    return IntTuple{compressed_append_pos_map};
+  }
+
+  /*!
+   * \brief Mark the KV data of sequence `seq_id` from position `begin` onward to be sent
+   * to the receiver at PE offset `recver_pe_offset`.
+   * \param seq_id The id of the sequence whose KV data is sent.
+   * \param begin The first position in the sequence whose KV data is sent.
+   * \param compressed_remote_position_map The receiver's slots in compressed form
+   *   [n, begin_1, length_1, ..., begin_n, length_n], expanded here to one slot per token.
+   * \param recver_pe_offset The PE offset of the receiver.
+   * \note Tokens already cached at positions [begin, seq_length) are queued for
+   *   page-to-page transfer; their local slots are recorded in ascending position order.
+   */
+  void DisaggMarkSend(int64_t seq_id, int64_t begin, const IntTuple& compressed_remote_position_map,
+                      int32_t recver_pe_offset) final {
+    ICHECK(f_transfer_kv_.defined());
+    auto it = seq_map_.find(seq_id);
+    CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
+    CHECK_GE(begin, 0) << "The KV transfer begin position must be non-negative.";
+    Sequence* sequence = &it->second;
+    KVTransferMetadata& meta = sequence->kv_transfer_metadata;
+    meta.start = begin;
+
+    // Decompress [n, begin_1, length_1, ..., begin_n, length_n] into one remote slot per token.
+    CHECK_GE(static_cast<int64_t>(compressed_remote_position_map.size()), 1)
+        << "The compressed remote position map must not be empty.";
+    int64_t nsegments = compressed_remote_position_map[0];
+    CHECK_EQ(static_cast<int64_t>(compressed_remote_position_map.size()), nsegments * 2 + 1)
+        << "The compressed remote position map size should be \"num_segments * 2 + 1\".";
+    meta.remote_position_map.clear();
+    for (int64_t seg = 0; seg < nsegments; ++seg) {
+      int64_t seg_begin = compressed_remote_position_map[2 * seg + 1];
+      int64_t seg_length = compressed_remote_position_map[2 * seg + 2];
+      for (int64_t j = 0; j < seg_length; ++j) {
+        meta.remote_position_map.push_back(seg_begin + j);
+      }
+    }
+    meta.recver_pe_offset = recver_pe_offset;
+
+    meta.local_position_map.clear();
+    const int64_t num_existing_to_send = sequence->seq_length - begin;
+    if (num_existing_to_send <= 0) {
+      return;
+    }
+    // Need to send existing KV. The remote map must also leave room for at least one new token.
+    CHECK_GT(static_cast<int64_t>(meta.remote_position_map.size()), num_existing_to_send)
+        << "Need at least one token to prefill";
+    meta.local_position_map.reserve(num_existing_to_send);
+    // Walk the blocks from last to first and, within each block, positions from last to
+    // first, collecting the local slots of the trailing `num_existing_to_send` tokens.
+    std::vector<int32_t> trace = sequence->GetBlockTrace(global_block_pool_);
+    for (auto it_block_id = trace.rbegin();
+         it_block_id != trace.rend() &&
+         static_cast<int64_t>(meta.local_position_map.size()) < num_existing_to_send;
+         ++it_block_id) {
+      const Block& block = global_block_pool_[*it_block_id];
+      for (int i = block.seq_length - 1;
+           i >= 0 && static_cast<int64_t>(meta.local_position_map.size()) < num_existing_to_send;
+           --i) {
+        // Positions at or beyond `sink_length` are shifted by `sliding_window_offset`.
+        int32_t offset =
+            i < block.sink_length ? i : i - block.sink_length + block.sliding_window_offset;
+        int page_id = block.page_ids[offset / page_size_];
+        int page_offset = offset % page_size_;
+        meta.local_position_map.push_back(page_id * page_size_ + page_offset);
+      }
+    }
+    // Slots were collected back to front; restore ascending position order.
+    std::reverse(meta.local_position_map.begin(), meta.local_position_map.end());
+  }
+
   void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, 
Optional<NDArray> mask,
                              NDArray o_data, double attn_score_scaling_factor) 
final {
     // Part 1. Shape and dtype check.
@@ -1777,6 +2060,10 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
           o_data.CreateView({total_seq_length, num_qo_heads_, head_dim_}, 
qkv_data->dtype);
     }
     // Part 2. Split fused qkv and apply rotary embedding to q/k data.
+    if (transfer_kv_) {
+      // The compute stream needs to wait for the KV transfer stream.
+      // NOTE(review): presumably so that k_data/v_data written below cannot be overwritten
+      // while a KV transfer issued earlier still reads them -- confirm buffer reuse.
+      DeviceAPI::Get(device_)->SyncStreamFromTo(device_, kv_transfer_stream_, 
compute_stream_);
+    }
     if (!rope_ext_factors_.defined()) {
       f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, 
k_data, v_data,
                       static_cast<int>(rope_mode_ == RoPEMode::kNormal));
@@ -1789,9 +2076,30 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     if (append_before_attn_) {
       f_transpose_append_(pages_[local_layer_id], k_data, v_data, 
append_position_map_view_);
     }
-    // Part 4: perform attention
+    // Part 4: KV transfer
+    if (page_to_page_transfer_kv_) {
+      // Send the KV data that was already in the cache, page to page. The KV transfer
+      // stream waits for the copy stream, presumably so that the page-to-page position
+      // maps staged there are ready.
+      DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, 
kv_transfer_stream_);
+      // FIXME: if the sender and recver's PP/TP degree do not match, we will need to first
+      // get the view of remote pages, and then take the specific remote layer.
+      // NOTE(review): the local layer's pages are passed as both the remote and the local
+      // pages, presumably relying on the symmetric NVSHMEM allocation -- confirm.
+      f_transfer_kv_page_to_page_.value()(pages_[local_layer_id], 
pages_[local_layer_id],
+                                          
kv_transfer_page_to_page_remote_position_map_view_,
+                                          
kv_transfer_page_to_page_local_position_map_view_,
+                                          
kv_transfer_page_to_page_recver_id_view_,
+                                          kv_transfer_stream_);
+    }
+    if (transfer_kv_) {
+      // Send the k/v data of the newly appended tokens.
+      // FIXME: if the sender and recver's PP/TP degree do not match, we will need to first
+      // get the view of remote pages, and then take the specific remote layer.
+      // The KV transfer stream needs to wait for the compute stream, which produces
+      // k_data/v_data.
+      DeviceAPI::Get(device_)->SyncStreamFromTo(device_, compute_stream_, 
kv_transfer_stream_);
+      f_transfer_kv_.value()(pages_[local_layer_id], k_data, v_data,
+                             kv_transfer_remote_position_map_view_, 
kv_transfer_recver_id_view_,
+                             kv_transfer_stream_);
+    }
+    // Part 5: perform attention
     AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, 
attn_score_scaling_factor);
-    // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not 
set.
+    // Part 6. Append k/v data to kv-cache if flag "append_before_attn" is not 
set.
     if (!append_before_attn_) {
       f_transpose_append_(pages_[local_layer_id], k_data, v_data, 
append_position_map_view_);
     }
@@ -2491,6 +2799,8 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
     total_append_length = cur_append_lengths_indptr_host_.back();
     ICHECK_EQ(total_append_length, append_position_map_host_.size());
+    ICHECK_EQ(total_append_length, 
kv_transfer_remote_position_map_host_.size());
+    ICHECK_EQ(total_append_length, kv_transfer_recver_id_host_.size());
 
     // - Reset the copy.
     aux_data_manager_->ResetAttnAuxDataCopy();
@@ -2553,7 +2863,26 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     // 9. append_position_map
     append_position_map_view_ =
         
aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_);
-    // 10. tree_attn_mask and tree_attn_mn_indptr
+    // 10. kv_transfer_remote_position_map
+    kv_transfer_remote_position_map_view_ = 
aux_data_manager_->CopyKVTransferRemotePositionMapAsync(
+        &kv_transfer_remote_position_map_host_);
+    // 11. kv_transfer_recver_id
+    kv_transfer_recver_id_view_ =
+        
aux_data_manager_->CopyKVTransferRecverIDAsync(&kv_transfer_recver_id_host_);
+
+    // 12. kv_transfer_page_to_page_local_position_map
+    kv_transfer_page_to_page_local_position_map_view_ =
+        aux_data_manager_->CopyKVTransferPage2PageLocalPositionMapAsync(
+            &kv_transfer_page_to_page_local_position_map_host_);
+    // 13. kv_transfer_page_to_page_remote_position_map
+    kv_transfer_page_to_page_remote_position_map_view_ =
+        aux_data_manager_->CopyKVTransferPage2PageRemotePositionMapAsync(
+            &kv_transfer_page_to_page_remote_position_map_host_);
+    // 14. kv_transfer_page_to_page_recver_id
+    kv_transfer_page_to_page_recver_id_view_ =
+        aux_data_manager_->CopyKVTransferPage2PageRecverIDAsync(
+            &kv_transfer_page_to_page_recver_id_host_);
+    // 15. tree_attn_mask and tree_attn_mn_indptr
     for (int d = 0; d < num_depths_; ++d) {
       if (!is_chain_on_depths_[d]) {
         tree_attn_mask_view_[d] =
@@ -2562,7 +2891,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
             
aux_data_manager_->CopyTreeAttnMNIndptrOnDepthAsync(&tree_attn_mn_indptr_host_[d],
 d);
       }
     }
-    // 11. Create view for temporary arrays for attention computation.
+    // 16. Create view for temporary arrays for attention computation.
     temp_attn_output_view_ = temp_attn_output_device_.CreateView(
         {total_append_length, num_qo_heads_, head_dim_}, 
temp_attn_output_device_->dtype);
     temp_attn_scores_view_ = temp_attn_scores_device_.CreateView(
@@ -2585,7 +2914,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 28 || args.size() == 29)
+      CHECK(args.size() == 29 || args.size() == 30)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
       ShapeTuple layer_indptr_tuple = args[1];
@@ -2626,10 +2955,14 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       PackedFunc f_attention_prefill_with_tree_mask = args[26];
       PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[27];
       Optional<NDArray> rope_ext_factors = NullOpt;
+      bool enable_kv_transfer = false;
 
-      if (args.size() >= 29 && args[28].IsObjectRef<NDArray>()) {
+      if (args[28].IsObjectRef<NDArray>()) {
         rope_ext_factors = args[28].AsObjectRef<NDArray>();
       }
+      if (args.size() >= 30) {
+        enable_kv_transfer = args[29];
+      }
 
       CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
@@ -2646,9 +2979,9 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
           page_size, num_layers, layer_id_begin_offset, num_qo_heads, 
num_kv_heads, head_dim,
           reserved_num_seqs, num_total_pages, prefill_chunk_size, 
support_sliding_window,
           RoPEMode(rope_mode), rotary_scale, rotary_theta, 
std::move(rope_ext_factors),  //
-          init->dtype, init->device, std::move(f_transpose_append), 
std::move(f_compact_copy),
-          std::move(f_attention_prefill), std::move(f_attention_decode),
-          std::move(f_attention_prefill_sliding_window),
+          enable_kv_transfer, init->dtype, init->device,                       
          //
+          std::move(f_transpose_append), std::move(f_compact_copy), 
std::move(f_attention_prefill),
+          std::move(f_attention_decode), 
std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),
           std::move(f_attention_prefill_with_tree_mask_paged_kv),
@@ -2663,7 +2996,7 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 22 || args.size() == 23)
+      CHECK(args.size() == 23 || args.size() == 24)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
       ShapeTuple layer_indptr_tuple = args[1];
@@ -2698,10 +3031,14 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
       PackedFunc f_attention_prefill_with_tree_mask = args[20];
       PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[21];
       Optional<NDArray> rope_ext_factors = NullOpt;
+      bool enable_kv_transfer = false;
 
-      if (args.size() >= 23 && args[22].IsObjectRef<NDArray>()) {
+      if (args[22].IsObjectRef<NDArray>()) {
         rope_ext_factors = args[22].AsObjectRef<NDArray>();
       }
+      if (args.size() >= 24) {
+        enable_kv_transfer = args[23];
+      }
 
       CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
@@ -2718,9 +3055,9 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
           page_size, num_layers, layer_id_begin_offset, num_qo_heads, 
num_kv_heads, head_dim,
           reserved_num_seqs, num_total_pages, prefill_chunk_size, 
support_sliding_window,
           RoPEMode(rope_mode), rotary_scale, rotary_theta, 
std::move(rope_ext_factors),  //
-          init->dtype, init->device, std::move(f_transpose_append), 
std::move(f_compact_copy),
-          std::move(f_attention_prefill), std::move(f_attention_decode),
-          std::move(f_attention_prefill_sliding_window),
+          enable_kv_transfer, init->dtype, init->device,                       
          //
+          std::move(f_transpose_append), std::move(f_compact_copy), 
std::move(f_attention_prefill),
+          std::move(f_attention_decode), 
std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),           //
           std::move(f_attention_prefill_with_tree_mask_paged_kv),  //
diff --git a/tests/python/disco/test_nvshmem.py 
b/tests/python/disco/test_nvshmem.py
index b304d145aa..1c4ffc9c4d 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -107,7 +107,7 @@ def test_nvshmem_init_finalize(session_kind: di.Session, 
num_workers: int):
     f_init_nvshmem_uid = 
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
     uid = f_init_nvshmem_uid()
     init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
-    init_dfunc(uid, num_workers)
+    init_dfunc(uid, num_workers, 0)
     sess.sync_worker_0()
     finalize_dfunc = 
sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
     finalize_dfunc()
@@ -123,7 +123,7 @@ def test_nvshmem_empty(session_kind: di.Session, 
num_workers: int):
     f_init_nvshmem_uid = 
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
     uid = f_init_nvshmem_uid()
     init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
-    init_dfunc(uid, num_workers)
+    init_dfunc(uid, num_workers, 0)
     sess.sync_worker_0()
     empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
     a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py 
b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
similarity index 57%
copy from 
tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
copy to tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index 82f85f4b17..483108ca83 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -40,6 +40,20 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
 )
 from tvm.runtime import ShapeTuple
 
+
+def get_comm_rank():
+    """Return ``(MPI.COMM_WORLD, rank)``, or ``(None, 0)`` if mpi4py is missing."""
+    try:
+        from mpi4py import MPI
+    except ImportError:
+        # Single-process fallback: no communicator, act as rank 0.
+        return None, 0
+    world = MPI.COMM_WORLD
+    return world, world.Get_rank()
+
+
+comm, rank = get_comm_rank()
+
 reserved_nseq = 32
 maximum_total_seq_length = 2048
 prefill_chunk_size = 512
@@ -52,7 +66,7 @@ rope_scale = 1.0
 rope_theta = 1e4
 rope_scaling = {}
 dtype = None
-device = tvm.cuda()
+device = tvm.cuda(rank)
 
 fclear = None
 fadd_sequence = None
@@ -66,6 +80,10 @@ fcommit_accepted_token_tree_nodes = None
 fattention_with_fuse_qkv = None
 fis_empty = None
 fdebug_get_kv = None
+fnvshmem_get_uid = None
+fnvshmem_init = None
+fdisagg_mark_send = None
+fdisagg_prepare_recv = None
 
 ftranspose_append = None
 fcopy_cache = None
@@ -91,6 +109,7 @@ def set_global_func(head_dim, dtype):
     global fattn_prefill_ragged, fattn_prefill_with_tree_mask, 
fattn_prefill_with_tree_mask_paged_kv_cache
     global fattn_prefill_sliding_window, fattn_decode_sliding_window
     global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, 
fcompact_copy
+    global fnvshmem_get_uid, fnvshmem_init, fdisagg_mark_send, 
fdisagg_prepare_recv
 
     fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
     fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
@@ -111,6 +130,11 @@ def set_global_func(head_dim, dtype):
     fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
     fdebug_get_kv = 
tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
 
+    fnvshmem_get_uid = 
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+    fnvshmem_init = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    fdisagg_mark_send = 
tvm.get_global_func("vm.builtin.kv_cache_disagg_mark_send")
+    fdisagg_prepare_recv = 
tvm.get_global_func("vm.builtin.kv_cache_disagg_prepare_recv")
+
     target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
@@ -193,6 +217,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, 
support_sliding_window):
         fattn_prefill_with_tree_mask,
         fattn_prefill_with_tree_mask_paged_kv_cache,
         None,
+        True,
     )
     return cache
 
@@ -277,6 +302,8 @@ def apply_attention(
     attn_sink_sizes: Optional[List[int]] = None,
     token_tree_parent_ptr_list: Optional[List[List[int]]] = None,
     accepted_leaf_indices: Optional[List[int]] = None,
+    only_update_host=False,
+    skip_add_sequence=False,
 ) -> None:
     seq_ids = []
     append_lengths = []
@@ -291,7 +318,8 @@ def apply_attention(
         if fork_parent_id is not None:
             assert fork_parent_id in cached_k
             assert seq_id not in cached_k
-            ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
+            if not only_update_host:
+                ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
             if fork_pos == -1:
                 cached_k[seq_id] = cached_k[fork_parent_id]
                 cached_v[seq_id] = cached_v[fork_parent_id]
@@ -299,7 +327,8 @@ def apply_attention(
                 cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos]
                 cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos]
         elif seq_id not in cached_k:
-            fadd_sequence(kv_cache, seq_id)
+            if not only_update_host and not skip_add_sequence:
+                fadd_sequence(kv_cache, seq_id)
             cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, 
head_dim), dtype)
             cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, 
head_dim), dtype)
 
@@ -324,17 +353,17 @@ def apply_attention(
                 )
             # depth of each node in the tree (this contains more than the last 
`append_length` nodes)
             token_tree_node_depths_list[i] = token_tree_node_depths
-
-    fbegin_forward(
-        kv_cache,
-        ShapeTuple(seq_ids),
-        ShapeTuple(append_lengths),
-        (
-            ShapeTuple(flattened_token_tree_parent_ptr)
-            if flattened_token_tree_parent_ptr is not None
-            else None
-        ),
-    )
+    if not only_update_host:
+        fbegin_forward(
+            kv_cache,
+            ShapeTuple(seq_ids),
+            ShapeTuple(append_lengths),
+            (
+                ShapeTuple(flattened_token_tree_parent_ptr)
+                if flattened_token_tree_parent_ptr is not None
+                else None
+            ),
+        )
 
     global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype)
     global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
@@ -365,9 +394,11 @@ def apply_attention(
                                 rope_offset,
                                 rope_scale,
                                 rope_theta,
-                                token_tree_node_depths_list[i][-append_length:]
-                                if token_tree_node_depths_list[i] is not None
-                                else None,
+                                (
+                                    
token_tree_node_depths_list[i][-append_length:]
+                                    if token_tree_node_depths_list[i] is not 
None
+                                    else None
+                                ),
                             )
                         )
                         for l in range(num_layers)
@@ -388,7 +419,8 @@ def apply_attention(
         values_np = global_new_v[layer_id]
         qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], 
axis=1), device)
         outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
-        fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
+        if not only_update_host:
+            fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
 
         # Compute attention expected results.
         outputs = np.expand_dims(outputs.numpy(), axis=0)
@@ -409,9 +441,11 @@ def apply_attention(
                     rope_offset,
                     rope_scale,
                     rope_theta,
-                    token_tree_node_depths_list[i][-append_length:]
-                    if token_tree_node_depths_list[i] is not None
-                    else None,
+                    (
+                        token_tree_node_depths_list[i][-append_length:]
+                        if token_tree_node_depths_list[i] is not None
+                        else None
+                    ),
                 )
             ).transpose(1, 0, 2)
             k_seq = (
@@ -464,21 +498,23 @@ def apply_attention(
                 ),
                 axis=0,
             ).astype(dtype)
-
-            tvm.testing.assert_allclose(
-                outputs[:, sum_length : sum_length + append_length, ...],
-                results,
-                rtol=1e-3,
-                atol=1e-3,
-            )
+            if not only_update_host:
+                tvm.testing.assert_allclose(
+                    outputs[:, sum_length : sum_length + append_length, ...],
+                    results,
+                    rtol=1e-3,
+                    atol=1e-3,
+                )
             sum_length += append_length
-    fend_forward(kv_cache)
+    if not only_update_host:
+        fend_forward(kv_cache)
 
     if accepted_leaf_indices is not None:
         seq_ids = [seq_id for seq_id, _ in batch]
-        fcommit_accepted_token_tree_nodes(
-            kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
-        )
+        if not only_update_host:
+            fcommit_accepted_token_tree_nodes(
+                kv_cache, ShapeTuple(seq_ids), 
ShapeTuple(accepted_leaf_indices)
+            )
         for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
             zip(accepted_leaf_indices, batch)
         ):
@@ -530,11 +566,11 @@ def apply_attention(
                 assert cached_k[seq_id].shape[1] == sliding_window_size
 
     # Verify
-    verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
+    if not only_update_host:
+        verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
 
 
[email protected]_gpu
[email protected]_cuda
[email protected](reason="Require NVSHMEM")
 def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
     kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
     if support_sliding_window and rope_mode == RopeMode.NORMAL:
@@ -558,413 +594,92 @@ def 
test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
         apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
 
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config):
-    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
-    if support_sliding_window and rope_mode == RopeMode.NORMAL:
-        # Normal RoPE mode under sliding window settings is not supported.
-        return
-    fclear(kv_cache)
-
-    num_sequences = 5
-    batch = [(seq_id, 1) for seq_id in range(num_sequences)]
-    cached_k = {}
-    cached_v = {}
-    for seq_id_to_remove in range(num_sequences):
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-        # Remove sequence.
-        fremove_sequence(kv_cache, seq_id_to_remove)
-        cached_k.pop(seq_id_to_remove)
-        cached_v.pop(seq_id_to_remove)
-        verify_cached_kv(
-            kv_cache,
-            seq_ids=[seq_id for seq_id in range(num_sequences) if seq_id != 
seq_id_to_remove],
-            expected_k=cached_k,
-            expected_v=cached_v,
-        )
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
[email protected](reason="Require NVSHMEM")
+def test_paged_attention_kv_cache_transfer(kv_cache_and_config):
     kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
-    if support_sliding_window and rope_mode == RopeMode.NORMAL:
-        # Normal RoPE mode under sliding window settings is not supported.
-        return
-    fclear(kv_cache)
-
-    cached_k = {}
-    cached_v = {}
-    batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
-    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-    # Fork existing sequences.
-    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71), ((9, 5, -1), 20)], 
cached_k, cached_v)
-    # 0 <- 5 <- 6,8,9
-    # 0 <- 7
-    # 3 <- 4
-    # Mixture of decode and prefill.
-    operation_seq = [
-        [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
-        [(7, 1), (6, 1), (8, 1), (9, 1)],
-        [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)],
-        [(7, 10), (6, 2), (8, 3), (9, 4)],
-    ]
-    for batch in operation_seq:
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-
-    apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45), ((12, 0, 15), 
14)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19), ((14, 0, 17), 
19)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8), ((16, 5, 80), 
10)], cached_k, cached_v)
-    apply_attention(
-        kv_cache,
-        rope_mode,
-        [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)],
-        cached_k,
-        cached_v,
-    )
-
-    operation_seq = [
-        [(6, 1), (11, 1), (13, 1), (9, 1)],
-        [(10, 1), (16, 1), (18, 1), (19, 1)],
-        [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)],
-        [(10, 10), (6, 2), (8, 3), (19, 4)],
-    ]
-    for batch in operation_seq:
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-
-    num_sequence = 20
-    for i in range(num_sequence):
-        fremove_sequence(kv_cache, i)
-        cached_k.pop(i)
-        cached_v.pop(i)
-        verify_cached_kv(
-            kv_cache,
-            seq_ids=list(range(i + 1, num_sequence)),
-            expected_k=cached_k,
-            expected_v=cached_v,
-        )
-
-    assert fis_empty(kv_cache), "The KV cache is not empty after removing all 
sequences"
-
-    # Test fork after page recycle
-    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
-
-    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, 
cached_v)
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_unlimited_depth(kv_cache_and_config):
-    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
-    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+    if support_sliding_window:
         # Normal RoPE mode under sliding window settings is not supported.
         return
+    np.random.seed(0)
     fclear(kv_cache)
-
-    cached_k = {}
-    cached_v = {}
-    apply_attention(kv_cache, rope_mode, [(0, 30)], cached_k, cached_v)
-    # Fork existing sequences.
-    apply_attention(kv_cache, rope_mode, [((1, 0, -1), 15)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 5)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((3, 2, -1), 20)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 26)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((5, 3, -1), 18)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((6, 5, -1), 22)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((7, 5, -1), 12)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((8, 7, -1), 29)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((9, 7, -1), 9)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((10, 9, -1), 31)], cached_k, 
cached_v)
-    apply_attention(kv_cache, rope_mode, [((11, 9, -1), 4)], cached_k, 
cached_v)
-    # 0 <- 1 <- 2 <- 3 <- 5 <- 7 <- 9 <- 11
-    #                |    |    |    |
-    #                4    6    8    10
-    # Decode.
-    operation_seq = [
-        [(3, 1), (6, 1), (9, 1)],
-        [(4, 1), (8, 1), (10, 1)],
-        [(5, 1), (7, 1), (11, 1)],
-    ]
-    for batch in operation_seq:
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-
-    num_sequence = 12
-    for i in range(num_sequence):
-        fremove_sequence(kv_cache, i)
-        cached_k.pop(i)
-        cached_v.pop(i)
-        verify_cached_kv(
-            kv_cache,
-            seq_ids=list(range(i + 1, num_sequence)),
-            expected_k=cached_k,
-            expected_v=cached_v,
-        )
-
-    assert fis_empty(kv_cache), "The KV cache is not empty after removing all 
sequences"
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_popn(kv_cache_and_config):
-    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
-    if support_sliding_window and rope_mode == RopeMode.NORMAL:
-        return
-    fclear(kv_cache)
-
-    cached_k = {}
-    cached_v = {}
-    batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
-    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, 
cached_v)
-
-    popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)]
-    for seq_id, pop_length in popn_operations:
-        fpopn(kv_cache, seq_id, pop_length)
-        if pop_length != 0:
-            cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
-            cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
-        verify_cached_kv(kv_cache, seq_ids=list(range(4)), 
expected_k=cached_k, expected_v=cached_v)
-
-    num_sequence = 5
-    for seq_id in range(num_sequence):
-        fremove_sequence(kv_cache, seq_id)
-        verify_cached_kv(
-            kv_cache,
-            seq_ids=list(range(seq_id + 1, num_sequence)),
-            expected_k=cached_k,
-            expected_v=cached_v,
-        )
-
-    assert fis_empty(kv_cache), "The KV cache is not empty after removing all 
sequences"
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
-    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
-    if not support_sliding_window or rope_mode == RopeMode.NORMAL:
-        return
-    fclear(kv_cache)
-
-    cached_k = {}
-    cached_v = {}
-    sliding_window_sizes = [20, 25, 30, 35, 40]
-    attn_sink_sizes = [6, 4, 8, 3, 7]
-    for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
-        zip(sliding_window_sizes, attn_sink_sizes)
-    ):
-        fadd_sequence(kv_cache, seq_id)
-        fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, 
attn_sink_size)
-        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), 
dtype)
-        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), 
dtype)
-
     # Prefill.
-    operation_seq = [[(0, 4)], [(1, 6)], [(2, 6), (3, 7), (4, 7)]]
-    operation_seq += [[(0, 20), (1, 19), (2, 30), (3, 35), (4, 40)]]
-    operation_seq += [[(0, 6), (1, 5), (2, 4), (3, 3), (4, 2)]]
-    for batch in operation_seq:
-        apply_attention(
-            kv_cache,
-            rope_mode,
-            batch,
-            cached_k,
-            cached_v,
-            sliding_window_sizes,
-            attn_sink_sizes,
-        )
+    prefill_operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 
19), (5, 20)]]
+    prefill_operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]]
+    prefill_operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), 
(5, 12), (7, 11)]]
+    prefill_len = {i: 0 for i in range(9)}
+    for batch in prefill_operation_seq:
+        for seq_id, append_length in batch:
+            prefill_len[seq_id] += append_length
     # Decode
-    batch = [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)]
-    for _ in range(20):
-        apply_attention(
-            kv_cache,
-            rope_mode,
-            batch,
-            cached_k,
-            cached_v,
-            sliding_window_sizes,
-            attn_sink_sizes,
-        )
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config):
-    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
-    if not support_sliding_window or rope_mode == RopeMode.NORMAL:
-        return
-    fclear(kv_cache)
-
-    cached_k = {}
-    cached_v = {}
-    sliding_window_sizes = [30, 35, 40]
-    attn_sink_sizes = [15, 20, 25]
-    for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
-        zip(sliding_window_sizes, attn_sink_sizes)
-    ):
-        fadd_sequence(kv_cache, seq_id)
-        fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, 
attn_sink_size)
-        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), 
dtype)
-        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), 
dtype)
-    apply_attention(
-        kv_cache,
-        rope_mode,
-        [(0, 12), (1, 18), (2, 28)],
-        cached_k,
-        cached_v,
-        sliding_window_sizes,
-        attn_sink_sizes,
-    )
-    # seq_len: [12, 18, 25+3]
-    sliding_window_sizes += [0, 0, 0]
-    attn_sink_sizes += [0, 0, 0]
-    apply_attention(
-        kv_cache,
-        rope_mode,
-        [((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)],
-        cached_k,
-        cached_v,
-        sliding_window_sizes,
-        attn_sink_sizes,
-    )
-    # seq_len: [12, 18, 25+3, 18, 38, 36]
-    apply_attention(
-        kv_cache,
-        rope_mode,
-        [(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)],
-        cached_k,
-        cached_v,
-        sliding_window_sizes,
-        attn_sink_sizes,
-    )
-    # seq_len: [15+6, 20+13, 25+7, 28, 41, 43]
-    sliding_window_sizes += [25]
-    attn_sink_sizes += [24]
-    ffork_sequence(kv_cache, 3, 6, 18)
-    fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], 
attn_sink_sizes[-1])
-    cached_k[6] = cached_k[3][::, :18]
-    cached_v[6] = cached_v[3][::, :18]
-    apply_attention(
-        kv_cache,
-        rope_mode,
-        [(3, 10), (6, 12)],
-        cached_k,
-        cached_v,
-        sliding_window_sizes,
-        attn_sink_sizes,
-    )
-    # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6]
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
-    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
-    if support_sliding_window:
-        # Normal RoPE mode under sliding window settings is not supported.
-        return
-    if rope_mode == RopeMode.INLINE:
-        # Inline RoPE mode is not supported for tree attention.
-        return
-    fclear(kv_cache)
+    decode_operation_seq = [
+        [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 
1)]
+    ]
+    decode_operation_seq += [
+        [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 
1)]
+    ]
+    decode_operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]]
+    decode_operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
 
     cached_k = {}
     cached_v = {}
-    # Prefill 4 sequences
-    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], 
cached_k, cached_v)
-    # Tree attention
-    apply_attention(
-        kv_cache,
-        rope_mode,
-        [(0, 7), (1, 15), (2, 10), (3, 14)],
-        cached_k,
-        cached_v,
-        token_tree_parent_ptr_list=[
-            [-1, 0, 0, 1, 1, 2, 2],  # complete binary tree of height 3
-            [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6],  # complete binary 
tree of height 4
-            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
-            [-1, 0, 0, 1, 1, 2, 2, -1, 7, 7, 8, 8, 9, 9],  # two complete 
binary trees of height 3
-        ],
-        accepted_leaf_indices=[6, 11, 6, 13],
-    )
-    # Do 5 rounds of decode.
-    for _ in range(5):
-        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], 
cached_k, cached_v)
+    if rank == 0:
+        for seq_id, _ in prefill_len.items():
+            fadd_sequence(kv_cache, seq_id)
+        remote_pos_maps = None
+        remote_pos_maps = comm.bcast(remote_pos_maps, root=1)
+        comm.Barrier()
+        for seq_id in prefill_len.keys():
+            fdisagg_mark_send(kv_cache, seq_id, 0, 
ShapeTuple(remote_pos_maps[seq_id]), 1)
+        for batch in prefill_operation_seq:
+            apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, 
skip_add_sequence=True)
+        device.sync()
+        comm.Barrier()
+    else:
+        remote_pos_maps = []
+        for seq_id, len in prefill_len.items():
+            fadd_sequence(kv_cache, seq_id)
+            compressed_pos_map = list(fdisagg_prepare_recv(kv_cache, seq_id, 
len))
+            remote_pos_maps.append(compressed_pos_map)
+        remote_pos_maps = comm.bcast(remote_pos_maps, root=1)
+        comm.Barrier()
+        for batch in prefill_operation_seq:
+            apply_attention(
+                kv_cache,
+                rope_mode,
+                batch,
+                cached_k,
+                cached_v,
+                only_update_host=True,
+                skip_add_sequence=True,
+            )
+        comm.Barrier()
+        for batch in decode_operation_seq:
+            apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, 
skip_add_sequence=True)
 
-    # Test the cases where all trees are chains.
-    fclear(kv_cache)
-    cached_k = {}
-    cached_v = {}
-    # Prefill 4 sequences
-    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], 
cached_k, cached_v)
-    # Tree attention
-    apply_attention(
-        kv_cache,
-        rope_mode,
-        [(0, 7), (1, 15), (2, 10), (3, 14)],
-        cached_k,
-        cached_v,
-        token_tree_parent_ptr_list=[
-            [-1, 0, 1, 2, 3, 4, 5],  # complete binary tree of height 7
-            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  # chain of 
length 15
-            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
-            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # chain of length 
14
-        ],
-        accepted_leaf_indices=[2, 6, -1, 4],
-    )
-    # Do 5 rounds of decode.
-    for _ in range(5):
-        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], 
cached_k, cached_v)
 
-    # Test the cases of tree attn with cached kv.
-    fclear(kv_cache)
-    cached_k = {}
-    cached_v = {}
-    # Prefill 4 sequences
-    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], 
cached_k, cached_v)
-    # Do 5 rounds of tree decode.
-    num_seq = 4
-    for i in range(5):
-        num_leaf_nodes = 2**i
-        parent_ptr = [(k - 1) // 2 for k in range(0, 2 * num_leaf_nodes - 1)]
-        apply_attention(
-            kv_cache,
-            rope_mode,
-            [(seq_id, num_leaf_nodes) for seq_id in range(num_seq)],
-            cached_k,
-            cached_v,
-            token_tree_parent_ptr_list=[parent_ptr for _ in range(num_seq)],
-            accepted_leaf_indices=(
-                None if i != 4 else [2, 6, -1, 4]
-            ),  # Leaf nodes are committed all at once at the end.
-        )
+def init_nvshmem(num_workers, pe_offset):
+    """Initialize NVSHMEM on this MPI process.
+
+    Rank 0 creates the NVSHMEM unique id, which is broadcast over ``comm`` to every rank;
+    each rank then calls ``runtime.disco.nvshmem.init_nvshmem(uid, num_workers, pe_offset)``.
+    NOTE(review): relies on the module-level ``comm`` and ``rank`` defined outside this hunk
+    (presumably the MPI world communicator and this process's rank) -- confirm.
+
+    Parameters
+    ----------
+    num_workers : int
+        Forwarded as the second argument of ``init_nvshmem``; presumably the total PE count.
+    pe_offset : int
+        Forwarded as the third argument; presumably the first PE index owned by this process.
+    """
+    if rank == 0:
+        f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+        uid = f_init_nvshmem_uid()
+    else:
+        uid = None
+    uid = comm.bcast(uid, root=0)
+    init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_func(uid, num_workers, pe_offset)
 
 
 if __name__ == "__main__":
-    HEAD_DIMS = [64, 128]
-    DTYPES = ["float16", "float32"]
-    ROPE_MODES = [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]
-    SUPPORT_SLIDING_WINDOW = [False, True]
+    # To run this test, install mpi4py first, and then run
+    # mpirun -np 2 python tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+    HEAD_DIMS = [128]
+    DTYPES = ["float16"]
+    ROPE_MODES = [RopeMode.NONE]
+    SUPPORT_SLIDING_WINDOW = [False]
+    init_nvshmem(2, rank)
     for head_dim, dtype, rope_mode, support_sliding_window in itertools.product(
         HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW
     ):
         set_global_func(head_dim, dtype)
         cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window)
         cache_and_config = (cache, rope_mode, support_sliding_window)
-        test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
-        test_paged_attention_kv_cache_remove_sequence(cache_and_config)
-        test_paged_attention_kv_cache_fork_sequence(cache_and_config)
-        test_paged_attention_kv_cache_popn(cache_and_config)
-        test_paged_attention_kv_cache_sliding_window(cache_and_config)
-        test_paged_attention_kv_cache_tree_attn(cache_and_config)
-        test_paged_attention_kv_cache_unlimited_depth(cache_and_config)
+        test_paged_attention_kv_cache_transfer(cache_and_config)
diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
new file mode 100644
index 0000000000..2a1f0047fa
--- /dev/null
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm._ffi.runtime_ctypes import Device
+from tvm.runtime import ShapeTuple
+from tvm.runtime import disco as di
+
+
+# Shared configuration for the tests below. The NVSHMEM page buffer is allocated with shape
+# (num_layers, num_pages, 2, num_kv_heads, page_size, head_dim): index 0 / 1 of the size-2
+# axis holds K / V, and token position p lives in page p // page_size at slot p % page_size.
+page_size = 4
+num_layers = 4
+num_kv_heads = 4
+head_dim = 128
+num_pages = 100
+# Number of tokens transferred per test; equals the length of the position maps used below.
+ntokens = 16
+
+
+def get_comm_rank():
+    """Return ``(comm, rank)``: the MPI world communicator and this process's rank in it."""
+    from mpi4py import MPI
+
+    world_comm = MPI.COMM_WORLD
+    return world_comm, world_comm.Get_rank()
+
+
+@pytest.mark.skip(reason="Require NVSHMEM")
+def test_kv_transfer_without_disco():
+    """Transfer K/V tensors from rank 0 into rank 1's NVSHMEM page buffer.
+
+    Run under ``mpirun -np 2``. Rank 0 calls ``nvshmem.KVTransfer`` to write ``ntokens``
+    K/V entries into layer ``layer_id`` at the slots listed in ``position_map_array``; after
+    the barrier, rank 1 checks that each slot holds the matching K (index 0) and V (index 1).
+    """
+    comm, rank = get_comm_rank()
+    layer_id = 1
+    dev = tvm.cuda(rank)
+    if rank == 0:
+        f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+        uid = f_init_nvshmem_uid()
+    else:
+        uid = None
+    uid = comm.bcast(uid, root=0)
+    init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_func(uid, 2, rank)
+    empty_func = tvm.get_global_func("runtime.disco.nvshmem.empty")
+    pages = empty_func(
+        ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)), "float16", dev
+    )
+    position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19, 25, 27]
+    # Both ranks draw the same data from the same seed, so rank 1 can verify without receiving it.
+    np.random.seed(0)
+    k_np = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    v_np = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    if rank == 0:
+        k = tvm.nd.array(k_np, dev)
+        v = tvm.nd.array(v_np, dev)
+        remote_position_map_np = np.array(position_map_array, dtype=np.int32)
+        remote_position_map = tvm.nd.array(remote_position_map_np, dev)
+        # PE offset 1 for every token; presumably addresses rank 1 (pe_offset=rank at init).
+        remote_tp_group_pe_offset_np = np.array([1] * len(position_map_array), dtype=np.int32)
+        remote_tp_group_pe_offset = tvm.nd.array(remote_tp_group_pe_offset_np, dev)
+        transfer_func = tvm.get_global_func("nvshmem.KVTransfer")
+        # View of one layer of the page buffer; the trailing `* 2` is sizeof(float16).
+        layer_view = pages._create_view(
+            [num_pages, 2, num_kv_heads, page_size, head_dim],
+            "float16",
+            relative_byte_offset=layer_id * num_pages * 2 * num_kv_heads * page_size * head_dim * 2,
+        )
+        transfer_func(layer_view, k, v, remote_position_map, remote_tp_group_pe_offset, None)
+        dev.sync()
+        comm.Barrier()
+    else:
+        comm.Barrier()
+        pages_np = pages.numpy()
+        for i, position in enumerate(position_map_array):
+            page_id = position // page_size
+            offset_in_page = position % page_size
+            original_k = k_np[i]
+            transferred_k = pages_np[layer_id, page_id, 0, :, offset_in_page, :]
+            np.testing.assert_allclose(original_k, transferred_k)
+            original_v = v_np[i]
+            transferred_v = pages_np[layer_id, page_id, 1, :, offset_in_page, :]
+            np.testing.assert_allclose(original_v, transferred_v)
+    finalize_func = tvm.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_func()
+    comm.Barrier()
+
+
+@pytest.mark.skip(reason="Require NVSHMEM")
+def test_kv_transfer_page_to_page_without_disco():
+    """Copy page-buffer entries from rank 0's pages into rank 1's pages.
+
+    Run under ``mpirun -np 2``. Rank 0 fills its page buffer with random data and calls
+    ``nvshmem.KVTransferPageToPage``; rank 1 then checks that the entry (K and V) at each slot
+    of ``rank_1_position_map_array`` equals rank 0's entry at the corresponding slot of
+    ``rank_0_position_map_array`` (the reversed map).
+    """
+    comm, rank = get_comm_rank()
+    layer_id = 1
+    dev = tvm.cuda(rank)
+    if rank == 0:
+        f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+        uid = f_init_nvshmem_uid()
+    else:
+        uid = None
+    uid = comm.bcast(uid, root=0)
+    init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_func(uid, 2, rank)
+    empty_func = tvm.get_global_func("runtime.disco.nvshmem.empty")
+    pages = empty_func(
+        ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)), "float16", dev
+    )
+    rank_1_position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19, 25, 27]
+    rank_0_position_map_array = list(reversed(rank_1_position_map_array))
+    # Same seed on both ranks: rank 1 regenerates rank 0's page contents for verification.
+    np.random.seed(0)
+    pages_np = np.random.rand(num_layers, num_pages, 2, num_kv_heads, page_size, head_dim).astype(
+        np.float16
+    )
+    if rank == 0:
+        pages.copyfrom(pages_np)
+        remote_position_map_np = np.array(rank_1_position_map_array, dtype=np.int32)
+        remote_position_map = tvm.nd.array(remote_position_map_np, dev)
+        local_position_map_np = np.array(rank_0_position_map_array, dtype=np.int32)
+        local_position_map = tvm.nd.array(local_position_map_np, dev)
+        remote_tp_group_pe_offset_np = np.array(
+            [1] * len(rank_0_position_map_array), dtype=np.int32
+        )
+        remote_tp_group_pe_offset = tvm.nd.array(remote_tp_group_pe_offset_np, dev)
+        transfer_func = tvm.get_global_func("nvshmem.KVTransferPageToPage")
+        # View of one layer of the page buffer; the trailing `* 2` is sizeof(float16).
+        layer_view = pages._create_view(
+            [num_pages, 2, num_kv_heads, page_size, head_dim],
+            "float16",
+            relative_byte_offset=layer_id * num_pages * 2 * num_kv_heads * page_size * head_dim * 2,
+        )
+        transfer_func(
+            layer_view,
+            layer_view,
+            remote_position_map,
+            local_position_map,
+            remote_tp_group_pe_offset,
+            None,
+        )
+        dev.sync()
+        comm.Barrier()
+    else:
+        comm.Barrier()
+        new_pages_np = pages.numpy()
+        for i, position in enumerate(rank_1_position_map_array):
+            page_id = position // page_size
+            offset_in_page = position % page_size
+            rank_0_position = rank_0_position_map_array[i]
+            rank_0_page_id = rank_0_position // page_size
+            rank_0_offset_in_page = rank_0_position % page_size
+            rank_0_entry = pages_np[layer_id, rank_0_page_id, :, :, rank_0_offset_in_page, :]
+            transferred_entry = new_pages_np[layer_id, page_id, :, :, offset_in_page, :]
+            np.testing.assert_allclose(rank_0_entry, transferred_entry)
+    finalize_func = tvm.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_func()
+    comm.Barrier()
+
+
+@pytest.mark.skip(reason="Require NVSHMEM")
+def test_kv_transfer_with_disco():
+    """Transfer K/V tensors between two MPI ranks that each drive a 2-worker disco session.
+
+    Run under ``mpirun -np 2``. NVSHMEM is initialized with 4 PEs and this rank's workers start
+    at PE ``rank * 2``. On rank 0, worker 0 sends ``k_np_0``/``v_np_0`` and worker 1 sends
+    ``k_np_1``/``v_np_1`` with a remote PE offset of 2; rank 1 then checks that its worker 0
+    and worker 1 page buffers hold the respective entries.
+    """
+    comm, rank = get_comm_rank()
+    layer_id = 1
+    if rank == 0:
+        f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+        uid = f_init_nvshmem_uid()
+    else:
+        uid = None
+    uid = comm.bcast(uid, root=0)
+    sess = di.ProcessSession(num_workers=2)
+    init_func = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_func(uid, 4, rank * 2)
+    empty_func = sess.get_global_func("runtime.disco.nvshmem.empty")
+    pages = empty_func(
+        ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)),
+        "float16",
+        Device(device_type=0, device_id=0),
+    )
+    position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19, 25, 27]
+    # Seed 0 -> worker 0's data, seed 1 -> worker 1's data; rank 1 regenerates both to verify.
+    np.random.seed(0)
+    k_np_0 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    v_np_0 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    np.random.seed(1)
+    k_np_1 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    v_np_1 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    if rank == 0:
+        k = sess.empty((ntokens, num_kv_heads, head_dim), "float16")
+        v = sess.empty((ntokens, num_kv_heads, head_dim), "float16")
+        k.debug_copy_from(0, k_np_0)
+        k.debug_copy_from(1, k_np_1)
+        v.debug_copy_from(0, v_np_0)
+        v.debug_copy_from(1, v_np_1)
+        remote_position_map_np = np.array(position_map_array, dtype=np.int32)
+        remote_position_map = sess.empty((len(position_map_array),), "int32")
+        remote_tp_group_pe_offset_np = np.array([2] * len(position_map_array), dtype=np.int32)
+        remote_tp_group_pe_offset = sess.empty((len(remote_tp_group_pe_offset_np),), "int32")
+        f_view_func = sess.get_global_func("runtime.TVMArrayCreateView")
+        # View of one layer of the page buffer; the trailing `* 2` is sizeof(float16).
+        layer_view = f_view_func(
+            pages,
+            ShapeTuple([num_pages, 2, num_kv_heads, page_size, head_dim]),
+            "float16",
+            layer_id * num_pages * 2 * num_kv_heads * page_size * head_dim * 2,
+        )
+        remote_position_map.debug_copy_from(0, remote_position_map_np)
+        remote_position_map.debug_copy_from(1, remote_position_map_np)
+        remote_tp_group_pe_offset.debug_copy_from(0, remote_tp_group_pe_offset_np)
+        remote_tp_group_pe_offset.debug_copy_from(1, remote_tp_group_pe_offset_np)
+        transfer_func = sess.get_global_func("nvshmem.KVTransfer")
+        transfer_func(layer_view, k, v, remote_position_map, remote_tp_group_pe_offset, None)
+        for i in range(2):
+            sess._sync_worker(i)
+        for i in range(2):
+            tvm.cuda(i).sync()
+        comm.Barrier()
+    else:
+        comm.Barrier()
+        # Worker w on rank 1 should hold exactly what worker w on rank 0 sent.
+        expected_kv = [(k_np_0, v_np_0), (k_np_1, v_np_1)]
+        for worker_id, (expected_k, expected_v) in enumerate(expected_kv):
+            pages_np = pages.debug_get_from_remote(worker_id).numpy()
+            for i, position in enumerate(position_map_array):
+                page_id = position // page_size
+                offset_in_page = position % page_size
+                transferred_k = pages_np[layer_id, page_id, 0, :, offset_in_page, :]
+                np.testing.assert_allclose(expected_k[i], transferred_k)
+                transferred_v = pages_np[layer_id, page_id, 1, :, offset_in_page, :]
+                np.testing.assert_allclose(expected_v[i], transferred_v)
+    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_dfunc()
+    for i in range(2):
+        sess._sync_worker(i)
+
+
+if __name__ == "__main__":
+    # To run this test, install mpi4py first, and then run
+    # mpirun -np 2 python tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py [TEST_NAME]  # pylint: disable=line-too-long
+    # FIXME: only one test can be run at a time (presumably because each test initializes and
+    # finalizes NVSHMEM itself), so the test is selected by name on the command line and
+    # defaults to test_kv_transfer_without_disco.
+    import sys
+
+    tests_by_name = {
+        "test_kv_transfer_without_disco": test_kv_transfer_without_disco,
+        "test_kv_transfer_with_disco": test_kv_transfer_with_disco,
+        "test_kv_transfer_page_to_page_without_disco": test_kv_transfer_page_to_page_without_disco,
+    }
+    test_name = sys.argv[1] if len(sys.argv) > 1 else "test_kv_transfer_without_disco"
+    if test_name not in tests_by_name:
+        raise SystemExit(f"Unknown test {test_name!r}; choose one of {sorted(tests_by_name)}")
+    tests_by_name[test_name]()
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 82f85f4b17..172eb20c26 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -193,6 +193,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
         fattn_prefill_with_tree_mask,
         fattn_prefill_with_tree_mask_paged_kv_cache,
         None,
+        False,
     )
     return cache
 

Reply via email to