This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c2866380a4 [relax] Fix tree attention for Qwen2-1.5 models (#17700)
c2866380a4 is described below

commit c2866380a42208439039333fdb9a08d2a96457b1
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Mar 4 00:20:16 2025 +0800

    [relax] Fix tree attention for Qwen2-1.5 models (#17700)
    
    Fix the compilation error for Qwen2-1.5 models in the tree attention
    implementation for the Vulkan backend.
---
 python/tvm/relax/frontend/nn/llm/tree_attn.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py
index 36a6e2dab8..3a666fb291 100644
--- a/python/tvm/relax/frontend/nn/llm/tree_attn.py
+++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -425,8 +425,8 @@ def tree_attn(
                                         batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x)
 
                                 if T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1):
-                                    b_idx: T.int32 = batch_idx[0]
-                                    LH_start: T.int32 = tile_id[0] * tile_x
+                                    b_idx: T.int32(is_size_var=True) = batch_idx[0]
+                                    LH_start: T.int32(is_size_var=True) = tile_id[0] * tile_x
                                     q_indptr_val: T.int32 = q_indptr[b_idx]
 
                                     kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
@@ -1049,8 +1049,8 @@ def tree_attn_with_paged_kv_cache(
                                         batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x)
 
                                 if T.tvm_thread_invariant(batch_idx[0] < batch_size):
-                                    b_idx: T.int32 = batch_idx[0]
-                                    LH_start: T.int32 = tile_id[0] * tile_x
+                                    b_idx: T.int32(is_size_var=True) = batch_idx[0]
+                                    LH_start: T.int32(is_size_var=True) = tile_id[0] * tile_x
                                     q_indptr_val: T.int32 = q_indptr[b_idx]
 
                                     cur_page_indptr_begin: T.int32 = page_indptr[b_idx]

Reply via email to