This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 45ab5fb6dd [Relax][PyTorch] Fix InternalError when converting scaled_dot_product_attention with 2D inputs (#18524)
45ab5fb6dd is described below

commit 45ab5fb6ddff9b03c72c4745557e6f602b97f3ba
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Sun Nov 30 00:53:36 2025 +0800

    [Relax][PyTorch] Fix InternalError when converting scaled_dot_product_attention with 2D inputs (#18524)
    
    Fixes #18441
    
    Previously, the TVM frontend assumed that inputs to
    scaled_dot_product_attention were always 4D, which caused an
    InternalError when the actual input was 2D (seq_len, head_dim).
    
    This fix:
    - Detects input dimensionality (2D vs 4D)
    - For 2D inputs: expands to 4D, calls attention, then squeezes back
      (see the sketch after this list)
    - For 4D inputs: maintains existing behavior
    - Adds test case for 2D input scenario
    - Updates verify_model_numerically to use strict=False for export
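
    A minimal sketch of why the 2D path is safe (plain PyTorch only; the
    (8, 32) shape is taken from the new test case, and the variable names
    and assert_close check are just for illustration): the 2D result equals
    expanding to 4D, running attention, and squeezing back, which is what
    the importer now emits in Relax.

        import torch
        import torch.nn.functional as F

        q = k = v = torch.randn(8, 32)                    # (seq_len, head_dim)
        direct = F.scaled_dot_product_attention(q, k, v)  # 2D path

        # Expand to (1, 1, seq_len, head_dim), attend, then squeeze back to 2D
        expanded = F.scaled_dot_product_attention(
            q[None, None], k[None, None], v[None, None]
        ).squeeze(0).squeeze(0)

        torch.testing.assert_close(direct, expanded)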
    
    ---------
    
    Co-authored-by: Masahiro Hiramori <[email protected]>
---
 .../frontend/torch/base_fx_graph_translator.py     | 56 ++++++++++++++++++----
 .../relax/test_frontend_from_exported_program.py   | 39 +++++++++++++++
 2 files changed, 87 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1938355169..e554648c41 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1477,10 +1477,49 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor))
 
     def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
-        transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
-        query = transpose_S_H(self.env[node.args[0]])
-        key = transpose_S_H(self.env[node.args[1]])
-        value = transpose_S_H(self.env[node.args[2]])
+        query_tensor = self.env[node.args[0]]
+        key_tensor = self.env[node.args[1]]
+        value_tensor = self.env[node.args[2]]
+
+        # Check the dimensionality of the input tensors
+        query_ndim = len(query_tensor.struct_info.shape)
+
+        # TVM's nn.attention requires 4D inputs in format (batch, num_heads, seq_len, head_dim)
+        # For 2D inputs (seq_len, head_dim), we need to reshape to 4D first
+        if query_ndim == 2:
+            # 2D input: (seq_len, head_dim) -> expand to (1, 1, seq_len, head_dim)
+            # Add batch dimension at axis 0
+            query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0))
+            key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0))
+            value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0))
+            # Add num_heads dimension at axis 1
+            query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=1))
+            key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=1))
+            value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=1))
+
+            # No permutation needed for 2D inputs after expanding to 4D
+            # After attention, squeeze back to 2D: (1, 1, seq_len, head_dim) -> (seq_len, head_dim)
+            def transpose_and_reshape_back(tensor):
+                # Squeeze batch and num_heads dimensions
+                return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1]))
+
+        elif query_ndim == 4:
+            # 4D input: (batch, seq_len, num_heads, head_dim)
+            # -> (batch, num_heads, seq_len, head_dim)
+            transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
+            query = self.block_builder.emit(transpose_S_H(query_tensor))
+            key = self.block_builder.emit(transpose_S_H(key_tensor))
+            value = self.block_builder.emit(transpose_S_H(value_tensor))
+
+            # For 4D, transpose back after attention
+            def transpose_and_reshape_back(tensor):
+                return self.block_builder.emit(transpose_S_H(tensor))
+
+        else:
+            raise ValueError(
+                f"scaled_dot_product_attention expects 2D or 4D inputs, but 
got {query_ndim}D input"
+            )
+
         attn_mask = node.args[3] if len(node.args) > 3 else 
node.kwargs.get("attn_mask", None)
         dropout_p = node.args[4] if len(node.args) > 4 else 
node.kwargs.get("dropout_p", 0.0)
         assert dropout_p == 0.0, "Dropout is not supported"
@@ -1492,12 +1531,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             msg = "Only a float mask is supported for the attn_mask input."
             assert "float" in attn_mask.struct_info.dtype, msg
 
-        return self.block_builder.emit(
-            transpose_S_H(
-                relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
-            )
+        attention_output = self.block_builder.emit(
+            relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
         )
 
+        return transpose_and_reshape_back(attention_output)
+
     def _unbind(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
@@ -1594,6 +1633,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         x = args[0]
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+
         # For boolean tensors, any is equivalent to max (checking if any element is True)
         return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim))
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9d8ad67e12..662df5e76a 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4294,6 +4294,45 @@ def test_scaled_dot_product_attention():
         run_ep_decomposition=True,
     )
 
+    # Test 2D input (seq_len, head_dim) - bug fix for #18441
+    class Attention2D(Module):
+        def forward(self, x):
+            return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False)
+
+    @I.ir_module
+    class Expected2D:
+        @R.function
+        def main(
+            x: R.Tensor((8, 32), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((8, 32), dtype="float32")):
+            with R.dataflow():
+                # Expand to add batch dimension for query, key, value separately
+                # (8, 32) -> (1, 8, 32)
+                lv: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
+                lv1: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
+                lv2: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
+                # Expand to add num_heads dimension: (1, 8, 32) -> (1, 1, 8, 32)
+                lv3: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=[1])
+                lv4: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv1, axis=[1])
+                lv5: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv2, axis=[1])
+                # Attention operation: (1, 1, 8, 32) -> (1, 1, 8, 32)
+                lv6: R.Tensor((1, 1, 8, 32), dtype="float32") = R.nn.attention(
+                    lv3, lv4, lv5, scale=None, causal_mask=None, window_size=None
+                )
+                # Squeeze batch and num_heads dimensions: (1, 1, 8, 32) -> (8, 32)
+                lv7: R.Tensor((8, 32), dtype="float32") = R.squeeze(lv6, axis=[0, 1])
+                gv: R.Tuple(R.Tensor((8, 32), dtype="float32")) = (lv7,)
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Attention2D(),
+        (torch.randn(8, 32, dtype=torch.float32),),
+        {},
+        Expected2D,
+        run_ep_decomposition=False,
+    )
+
 
 def test_unbind():
     class Unbind1(Module):
