This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c2866380a4 [relax] Fix tree attention for Qwen2-1.5 models (#17700)
c2866380a4 is described below
commit c2866380a42208439039333fdb9a08d2a96457b1
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Mar 4 00:20:16 2025 +0800
[relax] Fix tree attention for Qwen2-1.5 models (#17700)
Fix the compilation error for Qwen2-1.5 models in the tree attention
implementation for the Vulkan backend.
---
python/tvm/relax/frontend/nn/llm/tree_attn.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py
index 36a6e2dab8..3a666fb291 100644
--- a/python/tvm/relax/frontend/nn/llm/tree_attn.py
+++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -425,8 +425,8 @@ def tree_attn(
batch_tiles[0] =
T.ceildiv(batch_rows[0], tile_x)
if T.tvm_thread_invariant(batch_idx[0] <
batch_size_plus_1 - 1):
- b_idx: T.int32 = batch_idx[0]
- LH_start: T.int32 = tile_id[0] * tile_x
+ b_idx: T.int32(is_size_var=True) =
batch_idx[0]
+ LH_start: T.int32(is_size_var=True) =
tile_id[0] * tile_x
q_indptr_val: T.int32 = q_indptr[b_idx]
kv_chunk_len[0] = kv_indptr[b_idx + 1] -
kv_indptr[b_idx]
@@ -1049,8 +1049,8 @@ def tree_attn_with_paged_kv_cache(
batch_tiles[0] =
T.ceildiv(batch_rows[0], tile_x)
if T.tvm_thread_invariant(batch_idx[0] <
batch_size):
- b_idx: T.int32 = batch_idx[0]
- LH_start: T.int32 = tile_id[0] * tile_x
+ b_idx: T.int32(is_size_var=True) =
batch_idx[0]
+ LH_start: T.int32(is_size_var=True) =
tile_id[0] * tile_x
q_indptr_val: T.int32 = q_indptr[b_idx]
cur_page_indptr_begin: T.int32 =
page_indptr[b_idx]