This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2410c50cec [Fix] nn.attention support dynamic batch_size (#19779)
2410c50cec is described below

commit 2410c50cece24f24d4009074caf2aafcabe53144
Author: flashmouse <[email protected]>
AuthorDate: Tue Jun 16 03:53:02 2026 +0800

    [Fix] nn.attention support dynamic batch_size (#19779)
    
    This PR attempts to fix #19696: ``nn.attention`` now supports a dynamic batch_size.
    
    Co-authored-by: flashmouse <[email protected]>
---
 python/tvm/relax/transform/legalize_ops/nn.py      | 13 ++++----
 .../python/relax/test_transform_legalize_ops_nn.py | 35 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 51d23de0f7..35d81f968b 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -714,10 +714,11 @@ def _te_attention(
     q = topi.transpose(q, [0, 2, 1, 3])
     k = topi.transpose(k, [0, 2, 1, 3])
     v = topi.transpose(v, [0, 2, 1, 3])
-    q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
-    k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
-    v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v])
-    p = topi.nn.batch_matmul(q, k)
+    bs = batch_size * num_head
+    q = topi.reshape(q, [bs, seq_len, head_dim])
+    k = topi.reshape(k, [bs, seq_len_kv, head_dim])
+    v = topi.reshape(v, [bs, seq_len_kv, head_dim_v])
+    p = topi.nn.batch_matmul(q, k, oshape=[bs, seq_len, seq_len_kv])
     if scale is not None:
         p = topi.multiply(p, scale)
     else:
@@ -725,7 +726,7 @@ def _te_attention(
     if bias is not None:
         p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv])
         p = topi.add(p, bias)
-        p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv])
+        p = topi.reshape(p, [bs, seq_len, seq_len_kv])
     if causal_mask is None:
         s = topi.nn.softmax(p)
     else:
@@ -741,7 +742,7 @@ def _te_attention(
         )
         p_masked_sum = topi.sum(p_masked_exp, axis=-1, keepdims=True)
         s = topi.divide(p_masked_exp, p_masked_sum)
-    o = topi.nn.batch_matmul(s, v, transpose_b=False)
+    o = topi.nn.batch_matmul(s, v, transpose_b=False, oshape=[bs, seq_len, head_dim_v])
     o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v])
     return topi.transpose(o, [0, 2, 1, 3])
 
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 4a708b5da1..8136997cf6 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3727,6 +3727,41 @@ def test_dynamic_attention():
     LegalizeOps()(Attention)
 
 
+def test_dynamic_batch_attention():
+    """The batch dimension may be dynamic (symbolic).
+
+    fix https://github.com/apache/tvm/issues/19696
+    """
+
+    @tvm.script.ir_module
+    class Attention:
+        @R.function
+        def main(
+            q: R.Tensor(("batch_size", 16, 32, 8), "float32"),
+            k: R.Tensor(("batch_size", 8, 32, 8), "float32"),
+            v: R.Tensor(("batch_size", 8, 32, 16), "float32"),
+        ):
+            gv = R.nn.attention(q, k, v)
+            return gv
+
+    LegalizeOps()(Attention)
+
+    @tvm.script.ir_module
+    class AttentionBias:
+        @R.function
+        def main(
+            q: R.Tensor(("batch_size", 16, 32, 8), "float32"),
+            k: R.Tensor(("batch_size", 8, 32, 8), "float32"),
+            v: R.Tensor(("batch_size", 8, 32, 16), "float32"),
+            bias: R.Tensor(("batch_size", 32, 16, 8), "float32"),
+        ):
+            scale = T.FloatImm("float32", 0.1)
+            gv = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="BottomRight")
+            return gv
+
+    LegalizeOps()(AttentionBias)
+
+
 def test_nll_loss():
     # fmt: off
     @tvm.script.ir_module

Reply via email to