This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 45ab5fb6dd [Relax][PyTorch] Fix InternalError when converting scaled_dot_product_attention with 2D inputs (#18524)
45ab5fb6dd is described below
commit 45ab5fb6ddff9b03c72c4745557e6f602b97f3ba
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Sun Nov 30 00:53:36 2025 +0800
[Relax][PyTorch] Fix InternalError when converting scaled_dot_product_attention with 2D inputs (#18524)
Fixes #18441
Previously, the TVM frontend incorrectly assumed 4D input dimensions for
scaled_dot_product_attention, causing an InternalError when the actual
input was 2D (seq_len, head_dim).
This fix:
- Detects input dimensionality (2D vs 4D)
- For 2D inputs: expands to 4D, calls attention, then squeezes back (see the equivalence sketch below)
- For 4D inputs: maintains existing behavior
- Adds test case for 2D input scenario
- Updates verify_model_numerically to use strict=False for export
---------
Co-authored-by: Masahiro Hiramori <[email protected]>
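
For reference, the equivalence the converter now relies on can be checked directly in PyTorch. The short standalone sketch below is not part of the patch; the (8, 32) shape mirrors the new test case and the tolerance is an assumed value. It shows that a 2D (seq_len, head_dim) call to scaled_dot_product_attention matches the expand-to-4D / attend / squeeze-back rewrite:

    import torch
    import torch.nn.functional as F

    # Illustrative input: seq_len=8, head_dim=32, matching the new test case.
    q = torch.randn(8, 32)

    # 2D call, as PyTorch permits: (seq_len, head_dim) -> (seq_len, head_dim).
    out_2d = F.scaled_dot_product_attention(q, q, q)

    # The converter's rewrite: expand to (1, 1, seq_len, head_dim),
    # run attention in 4D, then squeeze the batch and num_heads axes back out.
    q_4d = q.unsqueeze(0).unsqueeze(0)
    out_4d = F.scaled_dot_product_attention(q_4d, q_4d, q_4d).squeeze(0).squeeze(0)

    # The two paths agree up to backend-dependent floating-point noise (assumed tolerance).
    assert torch.allclose(out_2d, out_4d, atol=1e-5)

Because the two paths compute the same attention, the converter only has to manage shapes, not values.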
---
.../frontend/torch/base_fx_graph_translator.py | 56 ++++++++++++++++++----
.../relax/test_frontend_from_exported_program.py | 39 +++++++++++++++
2 files changed, 87 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1938355169..e554648c41 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1477,10 +1477,49 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor))
 
     def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
-        transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
-        query = transpose_S_H(self.env[node.args[0]])
-        key = transpose_S_H(self.env[node.args[1]])
-        value = transpose_S_H(self.env[node.args[2]])
+        query_tensor = self.env[node.args[0]]
+        key_tensor = self.env[node.args[1]]
+        value_tensor = self.env[node.args[2]]
+
+        # Check the dimensionality of the input tensors
+        query_ndim = len(query_tensor.struct_info.shape)
+
+        # TVM's nn.attention requires 4D inputs in format (batch, num_heads, seq_len, head_dim)
+        # For 2D inputs (seq_len, head_dim), we need to reshape to 4D first
+        if query_ndim == 2:
+            # 2D input: (seq_len, head_dim) -> expand to (1, 1, seq_len, head_dim)
+            # Add batch dimension at axis 0
+            query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0))
+            key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0))
+            value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0))
+            # Add num_heads dimension at axis 1
+            query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=1))
+            key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=1))
+            value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=1))
+
+            # No permutation needed for 2D inputs after expanding to 4D
+            # After attention, squeeze back to 2D: (1, 1, seq_len, head_dim) -> (seq_len, head_dim)
+            def transpose_and_reshape_back(tensor):
+                # Squeeze batch and num_heads dimensions
+                return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1]))
+
+        elif query_ndim == 4:
+            # 4D input: (batch, seq_len, num_heads, head_dim)
+            # -> (batch, num_heads, seq_len, head_dim)
+            transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
+            query = self.block_builder.emit(transpose_S_H(query_tensor))
+            key = self.block_builder.emit(transpose_S_H(key_tensor))
+            value = self.block_builder.emit(transpose_S_H(value_tensor))
+
+            # For 4D, transpose back after attention
+            def transpose_and_reshape_back(tensor):
+                return self.block_builder.emit(transpose_S_H(tensor))
+
+        else:
+            raise ValueError(
+                f"scaled_dot_product_attention expects 2D or 4D inputs, but got {query_ndim}D input"
+            )
+
         attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None)
         dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0)
         assert dropout_p == 0.0, "Dropout is not supported"
@@ -1492,12 +1531,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             msg = "Only a float mask is supported for the attn_mask input."
             assert "float" in attn_mask.struct_info.dtype, msg
 
-        return self.block_builder.emit(
-            transpose_S_H(
-                relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
-            )
+        attention_output = self.block_builder.emit(
+            relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
         )
+        return transpose_and_reshape_back(attention_output)
+
 
     def _unbind(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
@@ -1594,6 +1633,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         x = args[0]
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+
         # For boolean tensors, any is equivalent to max (checking if any element is True)
         return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim))
 
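Usage note (not part of the patch): with the change above, the conversion that previously raised InternalError for 2D inputs (#18441) should go through the exported-program frontend. A minimal, hedged sketch follows; the module name and shapes are illustrative, and it assumes a TVM build that includes this commit:

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class SDPA2D(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.scaled_dot_product_attention(x, x, x)

    # Export without decomposition, using strict=False as in the updated test helper, then convert.
    ep = export(SDPA2D(), (torch.randn(8, 32),), strict=False)
    mod = from_exported_program(ep)
    mod.show()  # the 2D inputs are expanded to 4D around R.nn.attention and squeezed back
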
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9d8ad67e12..662df5e76a 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4294,6 +4294,45 @@ def test_scaled_dot_product_attention():
         run_ep_decomposition=True,
     )
 
+    # Test 2D input (seq_len, head_dim) - bug fix for #18441
+    class Attention2D(Module):
+        def forward(self, x):
+            return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False)
+
+    @I.ir_module
+    class Expected2D:
+        @R.function
+        def main(
+            x: R.Tensor((8, 32), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((8, 32), dtype="float32")):
+            with R.dataflow():
+                # Expand to add batch dimension for query, key, value separately
+                # (8, 32) -> (1, 8, 32)
+                lv: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
+                lv1: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
+                lv2: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
+                # Expand to add num_heads dimension: (1, 8, 32) -> (1, 1, 8, 32)
+                lv3: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=[1])
+                lv4: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv1, axis=[1])
+                lv5: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv2, axis=[1])
+                # Attention operation: (1, 1, 8, 32) -> (1, 1, 8, 32)
+                lv6: R.Tensor((1, 1, 8, 32), dtype="float32") = R.nn.attention(
+                    lv3, lv4, lv5, scale=None, causal_mask=None, window_size=None
+                )
+                # Squeeze batch and num_heads dimensions: (1, 1, 8, 32) -> (8, 32)
+                lv7: R.Tensor((8, 32), dtype="float32") = R.squeeze(lv6, axis=[0, 1])
+                gv: R.Tuple(R.Tensor((8, 32), dtype="float32")) = (lv7,)
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Attention2D(),
+        (torch.randn(8, 32, dtype=torch.float32),),
+        {},
+        Expected2D,
+        run_ep_decomposition=False,
+    )
+
 
 def test_unbind():
     class Unbind1(Module):