This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 67987c4592 [Runtime][KVCache] Adapt FlashInfer attention backend to 0.6.3 (#19904)
67987c4592 is described below
commit 67987c45929766d8f89712d58282767fe6aff805
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Jun 30 15:22:59 2026 -0400
[Runtime][KVCache] Adapt FlashInfer attention backend to 0.6.3 (#19904)
FlashInfer 0.6.3 changes the paged-attention plan/run ABI: the
prefill/decode plans take new arguments (e.g. window_left,
fixed_split_size, disable_split_kv; the decode plan now dispatches dtype
through empty q/kv tensors), the runs add enable_pdl and drop the
explicit stream, and the kernels consume separate key/value paged caches
read through tensor strides rather than one combined tensor. This updates
the runtime attention backend (paged MHA, ragged, decode and MLA) to the
new signatures and to the Array<int64_t> plan-info representation.
FlashInfer 0.6.3 reads tensors from `data` directly and does not honor
the DLPack `byte_offset` field. mlc's auxiliary index tensors (qo_indptr,
kv_indptr, page_indptr, page_indices, length_info) are views packed into
a shared workspace and so carry a non-zero byte_offset; passed as-is, the
kernels read the wrong addresses (e.g. a ragged prefill processed only
the first query row). Three zero-copy DLPack view helpers address this:
`ZeroByteOffsetView` folds byte_offset into the data pointer,
`PagedKVCacheView` exposes the combined (num_pages, 2, ...) page tensor
as separate strided key/value caches, and `SliceLastDimView` slices the
last dimension for MLA.
This also completes the MLA FlashInfer path, which previously shipped
only the test and module generator. The MLA run splits the query into
nope/pe parts and the paged cache into ckv/kpe parts, and the ragged
self-attention is given its own uncompressed head dims and per-query kv
head count via a 5-element backend spec, since they differ from the
compressed MLA cache.
The MHA and MLA FlashInfer KV-cache tests are re-enabled as regression
coverage, guarded on FlashInfer availability (inline-RoPE is skipped as
unsupported by FlashInfer).
---
src/runtime/vm/attn_backend.cc | 14 +-
src/runtime/vm/attn_backend.h | 277 +++++++++++++++++----
src/runtime/vm/paged_kv_cache.cc | 12 +
..._builtin_paged_attention_kv_cache_flashinfer.py | 24 +-
...ltin_paged_attention_kv_cache_mla_flashinfer.py | 17 +-
5 files changed, 283 insertions(+), 61 deletions(-)
diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc
index fdc88eb3b0..bb74d0e21a 100644
--- a/src/runtime/vm/attn_backend.cc
+++ b/src/runtime/vm/attn_backend.cc
@@ -59,11 +59,21 @@ std::unique_ptr<RaggedPrefillFunc>
ConvertRaggedPrefillFunc(ffi::Array<ffi::Any>
return std::make_unique<TIRRaggedPrefillFunc>(std::move(attn_func),
attn_kind);
}
if (backend_name == "flashinfer") {
- TVM_FFI_ICHECK_EQ(args.size(), 3);
+ // Regular MHA passes [backend, run, plan]. MLA self-attention additionally
+ // passes [qk_head_dim, v_head_dim] because the ragged kernel runs on
+ // different head dims than the compressed MLA cache.
+ TVM_FFI_ICHECK(args.size() == 3 || args.size() == 5);
ffi::Function attn_func = args[1].cast<ffi::Function>();
ffi::Function plan_func = args[2].cast<ffi::Function>();
+ int64_t qk_head_dim_override = -1;
+ int64_t v_head_dim_override = -1;
+ if (args.size() == 5) {
+ qk_head_dim_override = args[3].cast<int64_t>();
+ v_head_dim_override = args[4].cast<int64_t>();
+ }
return std::make_unique<FlashInferRaggedPrefillFunc>(std::move(attn_func),
std::move(plan_func),
- attn_kind);
+ attn_kind,
qk_head_dim_override,
+ v_head_dim_override);
}
TVM_FFI_THROW(InternalError) << "Cannot reach here";
throw;
diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h
index 6aececc755..64c2a872d6 100644
--- a/src/runtime/vm/attn_backend.h
+++ b/src/runtime/vm/attn_backend.h
@@ -49,6 +49,134 @@ enum class AttnBackendKind : int {
kFlashInfer = 1,
};
+/*!
+ * \brief Return a zero-copy alias of \p t whose `byte_offset` is folded into the
+ * data pointer, so the resulting tensor has `byte_offset == 0`.
+ *
+ * FlashInfer 0.6.3 kernels read tensors from `data` directly and do NOT honor
+ * the DLPack `byte_offset` field. mlc's auxiliary index tensors (qo_indptr,
+ * kv_indptr, page_indptr, page_indices, length_info, ...) are views packed into
+ * a shared workspace and therefore carry a non-zero `byte_offset`. Passing them
+ * as-is makes FlashInfer read the wrong addresses; this helper rebases them.
+ */
+inline ffi::Tensor ZeroByteOffsetView(const Tensor& t) {
+ if (t->byte_offset == 0) return t;
+ auto* holder = new Tensor(t); // keep the underlying storage alive
+ auto* managed = new DLManagedTensor();
+ managed->manager_ctx = holder;
+ managed->deleter = [](DLManagedTensor* self) {
+ delete[] self->dl_tensor.shape;
+ delete[] self->dl_tensor.strides;
+ delete static_cast<Tensor*>(self->manager_ctx);
+ delete self;
+ };
+ DLTensor& dl = managed->dl_tensor;
+ dl.data = static_cast<void*>(static_cast<char*>(t->data) + t->byte_offset);
+ dl.device = t->device;
+ dl.ndim = t->ndim;
+ dl.dtype = t->dtype;
+ dl.shape = new int64_t[t->ndim];
+ dl.strides = nullptr;
+ for (int i = 0; i < t->ndim; ++i) dl.shape[i] = t->shape[i];
+ if (t->strides != nullptr) {
+ dl.strides = new int64_t[t->ndim];
+ for (int i = 0; i < t->ndim; ++i) dl.strides[i] = t->strides[i];
+ }
+ dl.byte_offset = 0;
+ return tvm::ffi::Tensor::FromDLPack(managed, /*require_alignment=*/0,
+ /*require_contiguous=*/false);
+}
+
+/*!
+ * \brief Build a strided, zero-copy view selecting the key (which=0) or value
+ * (which=1) sub-tensor from a combined paged KV tensor of shape
+ * (num_pages, 2, num_heads, page_size, head_dim), yielding a
+ * (num_pages, num_heads, page_size, head_dim) tensor that shares storage with
+ * `pages`. FlashInfer 0.6.3 takes separate key/value paged caches and reads the
+ * tensor strides, so a strided view avoids an explicit split/copy.
+ */
+inline ffi::Tensor PagedKVCacheView(const Tensor& pages, int64_t which) {
+ TVM_FFI_ICHECK_EQ(pages->ndim, 5);
+ TVM_FFI_ICHECK_EQ(pages->shape[1], 2);
+ int64_t num_pages = pages->shape[0];
+ int64_t num_heads = pages->shape[2];
+ int64_t page_size = pages->shape[3];
+ int64_t head_dim = pages->shape[4];
+ int64_t inner = num_heads * page_size * head_dim;
+ int64_t elem_bytes = (pages->dtype.bits * pages->dtype.lanes + 7) / 8;
+
+ auto* holder = new Tensor(pages); // keep the underlying storage alive
+ auto* managed = new DLManagedTensor();
+ managed->manager_ctx = holder;
+ managed->deleter = [](DLManagedTensor* self) {
+ delete[] self->dl_tensor.shape;
+ delete[] self->dl_tensor.strides;
+ delete static_cast<Tensor*>(self->manager_ctx);
+ delete self;
+ };
+ DLTensor& dl = managed->dl_tensor;
+ dl.data = static_cast<void*>(static_cast<char*>(pages->data) + pages->byte_offset +
+ which * inner * elem_bytes);
+ dl.device = pages->device;
+ dl.ndim = 4;
+ dl.dtype = pages->dtype;
+ dl.shape = new int64_t[4]{num_pages, num_heads, page_size, head_dim};
+ dl.strides = new int64_t[4]{2 * inner, page_size * head_dim, head_dim, 1};
+ dl.byte_offset = 0;
+ return tvm::ffi::Tensor::FromDLPack(managed, /*require_alignment=*/0,
+ /*require_contiguous=*/false);
+}
+
+/*!
+ * \brief Return a strided, zero-copy view selecting the `[start, start+length)`
+ * slice along the LAST dimension of \p t, preserving all other strides and
+ * folding the slice offset into the data pointer (so `byte_offset == 0`).
+ *
+ * Used to split MLA tensors that store two head components concatenated along
+ * the last dim: the query into `q_nope`/`q_pe` and the paged cache into
+ * `ckv_cache`/`kpe_cache`. FlashInfer reads tensor strides and ignores
+ * `byte_offset`, so a strided slice avoids a copy.
+ */
+inline ffi::Tensor SliceLastDimView(const Tensor& t, int64_t start, int64_t length) {
+ int ndim = t->ndim;
+ int64_t elem_bytes = (t->dtype.bits * t->dtype.lanes + 7) / 8;
+ std::vector<int64_t> in_strides(ndim);
+ if (t->strides != nullptr) {
+ for (int i = 0; i < ndim; ++i) in_strides[i] = t->strides[i];
+ } else {
+ int64_t s = 1;
+ for (int i = ndim - 1; i >= 0; --i) {
+ in_strides[i] = s;
+ s *= t->shape[i];
+ }
+ }
+ auto* holder = new Tensor(t); // keep the underlying storage alive
+ auto* managed = new DLManagedTensor();
+ managed->manager_ctx = holder;
+ managed->deleter = [](DLManagedTensor* self) {
+ delete[] self->dl_tensor.shape;
+ delete[] self->dl_tensor.strides;
+ delete static_cast<Tensor*>(self->manager_ctx);
+ delete self;
+ };
+ DLTensor& dl = managed->dl_tensor;
+ dl.data = static_cast<void*>(static_cast<char*>(t->data) + t->byte_offset +
+ start * in_strides[ndim - 1] * elem_bytes);
+ dl.device = t->device;
+ dl.ndim = ndim;
+ dl.dtype = t->dtype;
+ dl.shape = new int64_t[ndim];
+ dl.strides = new int64_t[ndim];
+ for (int i = 0; i < ndim; ++i) {
+ dl.shape[i] = t->shape[i];
+ dl.strides[i] = in_strides[i];
+ }
+ dl.shape[ndim - 1] = length;
+ dl.byte_offset = 0;
+ return tvm::ffi::Tensor::FromDLPack(managed, /*require_alignment=*/0,
+ /*require_contiguous=*/false);
+}
+
/*! \brief The base class of attention backends. */
class AttnBackendFunc {
public:
@@ -139,12 +267,13 @@ class FlashInferPagedPrefillFunc : public
PagedPrefillFunc {
plan_info_vec] = cached_buffers_[depth];
double rope_rcp_scale = 1 / rotary_scale;
double rope_rcp_theta = 1 / rotary_theta;
- attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
pages, qo_indptr,
- page_indptr, page_indices, length_info, q_rope_position,
k_rope_pos_offset,
- attn_output, attn_lse,
/*mask_mode_code=*/static_cast<int64_t>(causal),
- /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode ==
RoPEMode::kInline),
- /*layout(HND)=*/1, -1, sm_scale,
/*rope_rcp_scale=*/rope_rcp_scale,
- /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+ attn_func_(
+ float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
PagedKVCacheView(pages, 0),
+ PagedKVCacheView(pages, 1), ZeroByteOffsetView(qo_indptr),
ZeroByteOffsetView(page_indptr),
+ ZeroByteOffsetView(page_indices), ZeroByteOffsetView(length_info),
attn_output, attn_lse,
+ /*mask_mode_code=*/static_cast<int64_t>(causal),
+ /*layout(HND)=*/1, /*window_left=*/-1, /*enable_pdl=*/false, sm_scale,
+ /*rope_rcp_scale=*/rope_rcp_scale, /*rope_rcp_theta=*/rope_rcp_theta);
}
void MLA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor
page_indptr,
@@ -152,9 +281,23 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc
{
Tensor attn_output, Tensor attn_lse, TVMStreamHandle
compute_stream) final {
auto [float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
plan_info_vec] = cached_buffers_[depth];
- attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
pages, page_indices,
- attn_output, attn_lse,
/*mask_mode_code=*/static_cast<int64_t>(causal),
- /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1],
sm_scale, compute_stream);
+ // FlashInfer's MLA run takes the query split into its compressed (nope)
and
+ // positional-embedding (pe) parts, and the paged cache split into the
+ // compressed-kv cache (ckv) and key-positional-embedding cache (kpe). Both
+ // q ([n, num_heads, ckv+kpe]) and pages ([num_pages, page_size, ckv+kpe])
+ // store the two components concatenated along the last dimension.
+ int64_t head_dim_ckv = mla_head_dim_ckv_;
+ int64_t head_dim_kpe = mla_head_dim_kpe_;
+ TVM_FFI_ICHECK_GE(head_dim_ckv, 0)
+ << "MLA head dims are unset; BeginForward must run before MLA.";
+ attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec,
+ SliceLastDimView(q, 0, head_dim_ckv),
+ SliceLastDimView(q, head_dim_ckv, head_dim_kpe),
+ SliceLastDimView(pages, 0, head_dim_ckv),
+ SliceLastDimView(pages, head_dim_ckv, head_dim_kpe),
+ ZeroByteOffsetView(page_indices), attn_output, attn_lse,
+ /*mask_mode_code=*/static_cast<int64_t>(causal),
/*num_heads=*/q->shape[1],
+ /*page_size=*/pages->shape[1], sm_scale,
/*return_lse_base_on_e=*/false);
}
void BeginForward(int depth, Tensor float_workspace_buffer, Tensor
int_workspace_buffer,
@@ -163,31 +306,37 @@ class FlashInferPagedPrefillFunc : public
PagedPrefillFunc {
int64_t batch_size, int64_t total_qo_len, int64_t
page_size,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t
qk_head_dim,
int64_t v_head_dim, bool causal, TVMStreamHandle
copy_stream) final {
- std::vector<int64_t> kv_len;
- kv_len.reserve(batch_size);
+ // FlashInfer expects kv_len as an (int32) tensor rather than a shape
tuple.
+ HostMemoryVector kv_len_arr(batch_size, DLDataType{kDLInt, 32, 1},
+ qo_indptr->as_tensor()->device);
for (int i = 0; i < static_cast<int>(batch_size); ++i) {
- kv_len.push_back((*page_indptr)[i + 1] != (*page_indptr)[i]
- ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) *
page_size +
- (*last_page_len)[i]
- : 0);
+ kv_len_arr.push_back(static_cast<int32_t>(
+ (*page_indptr)[i + 1] != (*page_indptr)[i]
+ ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * page_size +
(*last_page_len)[i]
+ : 0));
}
- ffi::Shape plan_info_vec;
+ ffi::Array<int64_t> plan_info_vec;
if (attn_kind == AttnKind::kMHA) {
// Todo(tvm-team): enable cuda graph
plan_info_vec =
plan_func_(float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
- qo_indptr->as_tensor(), page_indptr->as_tensor(),
- ffi::Shape(std::move(kv_len)), total_qo_len, batch_size,
num_qo_heads,
- num_kv_heads, page_size,
- /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim,
causal, copy_stream,
+ qo_indptr->as_tensor(), page_indptr->as_tensor(),
kv_len_arr.as_tensor(),
+ total_qo_len, batch_size, num_qo_heads, num_kv_heads,
page_size,
+ /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim,
causal,
+ /*window_left=*/-1, /*fixed_split_size=*/-1,
/*disable_split_kv=*/false,
/*num_colocated_ctas=*/0)
- .cast<ffi::Shape>();
+ .cast<ffi::Array<int64_t>>();
} else if (attn_kind == AttnKind::kMLA) {
+ // For MLA the compressed-kv head dim equals the output (v) head dim, and
+ // the remaining part of qk_head_dim is the key positional embedding.
Cache
+ // them for the run, which must split q/pages into ckv and kpe
components.
+ mla_head_dim_ckv_ = v_head_dim;
+ mla_head_dim_kpe_ = qk_head_dim - v_head_dim;
plan_info_vec =
plan_func_(float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
- qo_indptr->as_tensor(), page_indptr->as_tensor(),
- ffi::Shape(std::move(kv_len)), num_qo_heads, v_head_dim,
causal, copy_stream)
- .cast<ffi::Shape>();
+ qo_indptr->as_tensor(), page_indptr->as_tensor(),
kv_len_arr.as_tensor(),
+ num_qo_heads, v_head_dim, causal)
+ .cast<ffi::Array<int64_t>>();
}
if (cached_buffers_.size() <= static_cast<size_t>(depth)) {
@@ -200,7 +349,11 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc
{
private:
ffi::Function plan_func_;
- std::vector<std::tuple<Tensor, Tensor, Tensor, ffi::Shape>> cached_buffers_;
+ std::vector<std::tuple<Tensor, Tensor, Tensor, ffi::Array<int64_t>>>
cached_buffers_;
+ // MLA-only: the compressed-kv and key-positional-embedding head dims, used
to
+ // split q/pages in the run. Set during BeginForward for the kMLA attn kind.
+ int64_t mla_head_dim_ckv_ = -1;
+ int64_t mla_head_dim_kpe_ = -1;
};
/*! \brief The ragged prefill attention function base class. */
@@ -247,9 +400,12 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc {
class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc {
public:
explicit FlashInferRaggedPrefillFunc(ffi::Function attn_func, ffi::Function
plan_func,
- AttnKind attn_kind)
+ AttnKind attn_kind, int64_t
qk_head_dim_override = -1,
+ int64_t v_head_dim_override = -1)
: RaggedPrefillFunc(std::move(attn_func), attn_kind,
AttnBackendKind::kFlashInfer),
- plan_func_(std::move(plan_func)) {}
+ plan_func_(std::move(plan_func)),
+ qk_head_dim_override_(qk_head_dim_override),
+ v_head_dim_override_(v_head_dim_override) {}
void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr,
Tensor q_rope_position,
Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double
rotary_scale,
@@ -257,13 +413,12 @@ class FlashInferRaggedPrefillFunc : public
RaggedPrefillFunc {
TVMStreamHandle compute_stream) final {
double rope_rcp_scale = 1 / rotary_scale;
double rope_rcp_theta = 1 / rotary_theta;
- attn_func_(float_workspace_buffer_, int_workspace_buffer_, plan_info_vec_,
q, k, v, qo_indptr,
- kv_indptr, q_rope_position, k_rope_pos_offset, attn_output,
attn_lse,
+ attn_func_(float_workspace_buffer_, int_workspace_buffer_, plan_info_vec_,
q, k, v,
+ ZeroByteOffsetView(qo_indptr), ZeroByteOffsetView(kv_indptr),
attn_output, attn_lse,
/*mask_mode_code=*/static_cast<int64_t>(causal),
- /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode ==
RoPEMode::kInline),
- /*layout(NHD)=*/0, /*window_left=*/-1, sm_scale,
+ /*layout(NHD)=*/0, /*window_left=*/-1, /*enable_pdl=*/false,
sm_scale,
/*rope_rcp_scale=*/rope_rcp_scale,
- /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+ /*rope_rcp_theta=*/rope_rcp_theta);
}
void BeginForward(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
@@ -271,10 +426,21 @@ class FlashInferRaggedPrefillFunc : public
RaggedPrefillFunc {
HostMemoryVector* kv_indptr, int64_t batch_size, int64_t
total_qo_len,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t
qk_head_dim,
int64_t v_head_dim, bool causal, TVMStreamHandle
copy_stream) final {
- std::vector<int64_t> kv_len;
- kv_len.reserve(batch_size);
+ // For MLA self-attention the ragged kernel operates on different head dims
+ // than the (compressed) MLA cache, so they are supplied per-function via
the
+ // backend spec and override the cache-derived dims passed by the caller.
MLA
+ // self-attention is full multi-head (one kv head per query head), unlike
the
+ // single-head compressed cache, so the kv head count is overridden too.
+ if (qk_head_dim_override_ >= 0) qk_head_dim = qk_head_dim_override_;
+ if (v_head_dim_override_ >= 0) {
+ v_head_dim = v_head_dim_override_;
+ num_kv_heads = num_qo_heads;
+ }
+ // FlashInfer expects kv_len as an (int32) tensor rather than a shape
tuple.
+ HostMemoryVector kv_len_arr(batch_size, DLDataType{kDLInt, 32, 1},
+ qo_indptr->as_tensor()->device);
for (int i = 0; i < static_cast<int>(batch_size); ++i) {
- kv_len.push_back((*kv_indptr)[i + 1] - (*kv_indptr)[i]);
+ kv_len_arr.push_back(static_cast<int32_t>((*kv_indptr)[i + 1] -
(*kv_indptr)[i]));
}
// Todo(tvm-team): enable cuda graph
float_workspace_buffer_ = float_workspace_buffer;
@@ -282,11 +448,12 @@ class FlashInferRaggedPrefillFunc : public
RaggedPrefillFunc {
page_locked_int_workspace_buffer_ = page_locked_int_workspace_buffer;
plan_info_vec_ =
plan_func_(float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
- qo_indptr->as_tensor(), kv_indptr->as_tensor(),
ffi::Shape(std::move(kv_len)),
+ qo_indptr->as_tensor(), kv_indptr->as_tensor(),
kv_len_arr.as_tensor(),
total_qo_len, batch_size, num_qo_heads, num_kv_heads,
/*page_size=*/1,
- /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim,
causal, copy_stream,
+ /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim,
causal,
+ /*window_left=*/-1, /*fixed_split_size=*/-1,
/*disable_split_kv=*/false,
/*num_colocated_ctas=*/0)
- .cast<ffi::Shape>();
+ .cast<ffi::Array<int64_t>>();
}
private:
@@ -294,7 +461,11 @@ class FlashInferRaggedPrefillFunc : public
RaggedPrefillFunc {
Tensor float_workspace_buffer_;
Tensor int_workspace_buffer_;
Tensor page_locked_int_workspace_buffer_;
- ffi::Shape plan_info_vec_;
+ ffi::Array<int64_t> plan_info_vec_;
+ // MLA self-attention head dims supplied via the backend spec; -1 means use
the
+ // dims passed by the caller (the regular MHA case).
+ int64_t qk_head_dim_override_ = -1;
+ int64_t v_head_dim_override_ = -1;
};
/*! \brief The paged decode attention function base class. */
@@ -366,11 +537,12 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
plan_info_vec] = cached_buffers_[depth];
double rope_rcp_scale = 1 / rotary_scale;
double rope_rcp_theta = 1 / rotary_theta;
- attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
pages, page_indptr,
- page_indices, length_info, q_rope_position, k_rope_pos_offset,
attn_output, attn_lse,
- /*pos_encoding_mode_code=*/static_cast<int64_t>(rope_mode ==
RoPEMode::kInline),
- /*layout(HND)=*/1, /*window_left=*/-1, sm_scale,
/*rope_rcp_scale=*/rope_rcp_scale,
- /*rope_rcp_theta=*/rope_rcp_theta, compute_stream);
+ attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
+ PagedKVCacheView(pages, 0), PagedKVCacheView(pages, 1),
+ ZeroByteOffsetView(page_indptr),
ZeroByteOffsetView(page_indices),
+ ZeroByteOffsetView(length_info), attn_output, attn_lse,
/*kv_layout_code(HND)=*/1,
+ /*window_left=*/-1, /*enable_pdl=*/false, sm_scale,
+ /*rope_rcp_scale=*/rope_rcp_scale,
/*rope_rcp_theta=*/rope_rcp_theta);
}
void BeginForward(int depth, Tensor float_workspace_buffer, Tensor
int_workspace_buffer,
@@ -380,13 +552,18 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
RoPEMode rope_mode, DLDataType q_dtype, DLDataType
kv_dtype,
TVMStreamHandle copy_stream) final {
// Todo(tvm-team): enable cuda graph
- ffi::Shape plan_info_vec =
+ // FlashInfer's decode plan takes empty q/kv tensors (used only for dtype
+ // dispatch) instead of dtype scalars, adds a logits_soft_cap argument, and
+ // no longer takes the pos-encoding mode or an explicit stream.
+ DLDevice device = float_workspace_buffer->device;
+ Tensor empty_q_data = Tensor::Empty(ffi::Shape({0}), q_dtype, device);
+ Tensor empty_kv_data = Tensor::Empty(ffi::Shape({0}), kv_dtype, device);
+ ffi::Array<int64_t> plan_info_vec =
plan_func_(float_workspace_buffer, int_workspace_buffer,
page_locked_int_workspace_buffer,
page_indptr->as_tensor(), batch_size, num_qo_heads,
num_kv_heads, page_size,
- /*enable_cuda_graph=*/false,
- static_cast<int64_t>(rope_mode == RoPEMode::kInline),
- /*window_left=*/-1, qk_head_dim, v_head_dim, q_dtype,
kv_dtype, copy_stream)
- .cast<ffi::Shape>();
+ /*enable_cuda_graph=*/false, /*window_left=*/-1,
/*logits_soft_cap=*/0.0,
+ qk_head_dim, v_head_dim, empty_q_data, empty_kv_data)
+ .cast<ffi::Array<int64_t>>();
if (cached_buffers_.size() <= static_cast<size_t>(depth)) {
cached_buffers_.resize(depth + 1);
@@ -398,7 +575,7 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc {
private:
ffi::Function plan_func_;
- std::vector<std::tuple<Tensor, Tensor, Tensor, ffi::Shape>> cached_buffers_;
+ std::vector<std::tuple<Tensor, Tensor, Tensor, ffi::Array<int64_t>>>
cached_buffers_;
};
/*! \brief The paged prefill with tree mask attention function base class. */
diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc
index cd7920d6ee..be8751df7a 100644
--- a/src/runtime/vm/paged_kv_cache.cc
+++ b/src/runtime/vm/paged_kv_cache.cc
@@ -2289,7 +2289,19 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
// - Sync Tensors to GPU.
SyncAuxArrayToDevice();
+ // FlashInfer's plan kernels no longer take an explicit stream argument;
they
+ // run on the device's *current* stream. Make the copy stream current
around
+ // the plan so its workspace writes happen on the copy stream -- matching
the
+ // aux-array copies above and the copy->compute synchronization below (this
+ // preserves the prior behavior where copy_stream_ was passed explicitly).
+ bool plan_on_copy_stream = copy_stream_ != nullptr && copy_stream_ !=
compute_stream_;
+ if (plan_on_copy_stream) {
+ DeviceAPI::Get(device_)->SetStream(device_, copy_stream_);
+ }
KernelBeginForward();
+ if (plan_on_copy_stream) {
+ DeviceAPI::Get(device_)->SetStream(device_, compute_stream_);
+ }
// - Clear the dirty flag.
dirty_aux_data_device_ = false;
// - If there is no particular copy stream, no action is needed.
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 3a4454e7c9..2e380d0cbd 100644
---
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -35,6 +35,17 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
)
from tvm.s_tir import dlight as dl
+
+def has_flashinfer():
+ """Check whether FlashInfer (with the JIT module generator) is available."""
+ try:
+ from flashinfer.jit import gen_customize_batch_prefill_module # noqa: F401
+
+ return True
+ except ImportError:
+ return False
+
+
reserved_nseq = 32
maximum_total_seq_length = 2048
prefill_chunk_size = 512
@@ -195,7 +206,14 @@ def create_kv_cache(rope_mode):
@pytest.fixture(params=[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE])
def kv_cache_and_rope_mode(request):
- set_global_func()
+ if not has_flashinfer():
+ pytest.skip("FlashInfer is not available")
+ if request.param == RopeMode.INLINE:
+ # FlashInfer does not support inline RoPE (see the assertion in
+ # tvm/relax/frontend/nn/llm/kv_cache.py); models pair FlashInfer with
+ # NORMAL/NONE rope and apply rotary embedding in a separate kernel.
+ pytest.skip("FlashInfer does not support inline RoPE mode")
+ set_global_func(request.param)
return create_kv_cache(request.param), request.param
@@ -410,7 +428,6 @@ def apply_attention(
verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
[email protected](reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_rope_mode):
kv_cache, rope_mode = kv_cache_and_rope_mode
fclear(kv_cache)
@@ -431,7 +448,6 @@ def
test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_rope_mode):
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
[email protected](reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode):
kv_cache, rope_mode = kv_cache_and_rope_mode
fclear(kv_cache)
@@ -454,7 +470,6 @@ def
test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode):
)
[email protected](reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
kv_cache, rope_mode = kv_cache_and_rope_mode
fclear(kv_cache)
@@ -520,7 +535,6 @@ def
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k,
cached_v)
[email protected](reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
kv_cache, rope_mode = kv_cache_and_rope_mode
fclear(kv_cache)
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index ef2aa35ecd..20c9cb01ed 100644
---
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -35,6 +35,17 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
)
from tvm.s_tir import dlight as dl
+
+def has_flashinfer():
+ """Check whether FlashInfer (with the JIT module generator) is available."""
+ try:
+ from flashinfer.jit import gen_batch_mla_module # noqa: F401
+
+ return True
+ except ImportError:
+ return False
+
np.random.seed(0)
reserved_nseq = 32
@@ -225,6 +236,8 @@ def create_kv_cache(dtype):
@pytest.fixture(params=itertools.product(["float16"]))
def kv_cache_and_config(request):
+ if not has_flashinfer():
+ pytest.skip("FlashInfer is not available")
global dtype, dtype_torch
(dtype,) = request.param
dtype_torch = getattr(torch, dtype)
@@ -430,7 +443,6 @@ def apply_attention(
verify_cached_kv(kv_cache, seq_ids, cached_kv)
[email protected](reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
(kv_cache,) = kv_cache_and_config
fclear(kv_cache)
@@ -450,7 +462,6 @@ def
test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
apply_attention(kv_cache, batch, cached_kv)
[email protected](reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config):
(kv_cache,) = kv_cache_and_config
fclear(kv_cache)
@@ -470,7 +481,6 @@ def
test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config):
)
[email protected](reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
(kv_cache,) = kv_cache_and_config
fclear(kv_cache)
@@ -539,7 +549,6 @@ def
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
apply_attention(kv_cache, [(10, 1), (12, 1)], cached_kv)
[email protected](reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_popn(kv_cache_and_config):
(kv_cache,) = kv_cache_and_config
fclear(kv_cache)