This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2410c50cec [Fix] nn.attention support dynamic batch_size (#19779)
2410c50cec is described below
commit 2410c50cece24f24d4009074caf2aafcabe53144
Author: flashmouse <[email protected]>
AuthorDate: Tue Jun 16 03:53:02 2026 +0800
[Fix] nn.attention support dynamic batch_size (#19779)
This PR attempts to fix #19696: ``nn.attention`` now supports a dynamic (symbolic) batch_size.
Co-authored-by: flashmouse <[email protected]>
---
python/tvm/relax/transform/legalize_ops/nn.py | 13 ++++----
.../python/relax/test_transform_legalize_ops_nn.py | 35 ++++++++++++++++++++++
2 files changed, 42 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 51d23de0f7..35d81f968b 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -714,10 +714,11 @@ def _te_attention(
q = topi.transpose(q, [0, 2, 1, 3])
k = topi.transpose(k, [0, 2, 1, 3])
v = topi.transpose(v, [0, 2, 1, 3])
- q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
- k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
- v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v])
- p = topi.nn.batch_matmul(q, k)
+ bs = batch_size * num_head
+ q = topi.reshape(q, [bs, seq_len, head_dim])
+ k = topi.reshape(k, [bs, seq_len_kv, head_dim])
+ v = topi.reshape(v, [bs, seq_len_kv, head_dim_v])
+ p = topi.nn.batch_matmul(q, k, oshape=[bs, seq_len, seq_len_kv])
if scale is not None:
p = topi.multiply(p, scale)
else:
@@ -725,7 +726,7 @@ def _te_attention(
if bias is not None:
p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv])
p = topi.add(p, bias)
- p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv])
+ p = topi.reshape(p, [bs, seq_len, seq_len_kv])
if causal_mask is None:
s = topi.nn.softmax(p)
else:
@@ -741,7 +742,7 @@ def _te_attention(
)
p_masked_sum = topi.sum(p_masked_exp, axis=-1, keepdims=True)
s = topi.divide(p_masked_exp, p_masked_sum)
- o = topi.nn.batch_matmul(s, v, transpose_b=False)
+ o = topi.nn.batch_matmul(s, v, transpose_b=False, oshape=[bs, seq_len,
head_dim_v])
o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v])
return topi.transpose(o, [0, 2, 1, 3])
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 4a708b5da1..8136997cf6 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3727,6 +3727,41 @@ def test_dynamic_attention():
LegalizeOps()(Attention)
+def test_dynamic_batch_attention():
+ """The batch dimension may be dynamic (symbolic).
+
+ fix https://github.com/apache/tvm/issues/19696
+ """
+
+ @tvm.script.ir_module
+ class Attention:
+ @R.function
+ def main(
+ q: R.Tensor(("batch_size", 16, 32, 8), "float32"),
+ k: R.Tensor(("batch_size", 8, 32, 8), "float32"),
+ v: R.Tensor(("batch_size", 8, 32, 16), "float32"),
+ ):
+ gv = R.nn.attention(q, k, v)
+ return gv
+
+ LegalizeOps()(Attention)
+
+ @tvm.script.ir_module
+ class AttentionBias:
+ @R.function
+ def main(
+ q: R.Tensor(("batch_size", 16, 32, 8), "float32"),
+ k: R.Tensor(("batch_size", 8, 32, 8), "float32"),
+ v: R.Tensor(("batch_size", 8, 32, 16), "float32"),
+ bias: R.Tensor(("batch_size", 32, 16, 8), "float32"),
+ ):
+ scale = T.FloatImm("float32", 0.1)
+ gv = R.nn.attention(q, k, v, bias, scale=scale,
causal_mask="BottomRight")
+ return gv
+
+ LegalizeOps()(AttentionBias)
+
+
def test_nll_loss():
# fmt: off
@tvm.script.ir_module