wuyii8941 opened a new issue, #19576:
URL: https://github.com/apache/tvm/issues/19576
## Description
The `AdjustMatmulOrder` pass crashes with `Check failed: shape_c.size() == 2
(3 vs. 2)` when encountering a matmul chain where the intermediate result is 3D
and the final weight is 2D.
This is a common pattern in transformer models: `matmul(attn_output[B,S,D],
W_o[D,D])` where the attention output is 3D (batched) and the output projection
weight is 2D.
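The shape semantics of this pattern can be seen in a minimal NumPy sketch. It assumes that Relax `matmul` follows NumPy-style broadcasting over batch dimensions. Under that assumption, the 2D weight is applied to every `[S, D]` slice of the batched input, and the result stays 3D:

```python
import numpy as np

# Shapes from the reproducer: batch B, sequence length S, hidden size D.
B, S, D = 2, 16, 64
rng = np.random.default_rng(0)

attn_output = rng.standard_normal((B, S, D))  # 3D, batched
W_o = rng.standard_normal((D, D))             # 2D output-projection weight

# np.matmul broadcasts the 2D weight across the leading batch dimension.
proj = np.matmul(attn_output, W_o)
print(proj.shape)  # (2, 16, 64)

# Equivalent explicit per-batch computation.
per_batch = np.stack([attn_output[b] @ W_o for b in range(B)])
print(np.allclose(proj, per_batch))  # True
```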
## Reproducer
```python
import numpy as np
import tvm
from tvm import relax
import tvm.relax.op as R

B, S, D = 2, 16, 64

bb = relax.BlockBuilder()
x = relax.Var('x', relax.TensorStructInfo((B, S, D), 'float32'))
wq = relax.Var('wq', relax.TensorStructInfo((D, D), 'float32'))
wk = relax.Var('wk', relax.TensorStructInfo((D, D), 'float32'))
wv = relax.Var('wv', relax.TensorStructInfo((D, D), 'float32'))
wo = relax.Var('wo', relax.TensorStructInfo((D, D), 'float32'))

with bb.function('main', [x, wq, wk, wv, wo]):
    with bb.dataflow():
        q = bb.emit(R.matmul(x, wq))
        k = bb.emit(R.matmul(x, wk))
        v = bb.emit(R.matmul(x, wv))
        kt = bb.emit(R.permute_dims(k, [0, 2, 1]))
        scores = bb.emit(R.matmul(q, kt))
        scale = relax.const(1.0 / np.sqrt(D), 'float32')
        scores = bb.emit(R.multiply(scores, scale))
        attn = bb.emit(R.nn.softmax(scores, axis=-1))
        out = bb.emit(R.matmul(attn, v))  # 3D result
        proj = bb.emit_output(R.matmul(out, wo))  # 3D @ 2D → crash
    bb.emit_func_output(proj)
mod = bb.finalize()

# This crashes:
pipeline = tvm.ir.transform.Sequential([
    relax.transform.AdjustMatmulOrder(),
    relax.transform.LegalizeOps(),
])
mod_l = pipeline(mod)  # Check failed: shape_c.size() == 2 (3 vs. 2)
```
## Error
```
tvm.error.InternalError: Check failed: shape_c.size() == 2 (3 vs. 2) :
```
## Root cause
The `AdjustMatmulOrder` pass assumes that every operand in a matmul chain is 2D
(`shape.size() == 2`). When `matmul(attn, v)` produces a 3D tensor `[B, S, D]`
that is then multiplied by the 2D weight `[D, D]`, the assertion fails.
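The "skip gracefully" option can be illustrated with a minimal Python sketch of a rank guard. It returns `None` for a non-2D chain instead of asserting. `choose_order` and `matmul_flops` are hypothetical helpers, not TVM APIs. Choosing the association by FLOP count is also an assumption made for illustration; the actual C++ pass may weigh the two orders differently.

```python
from typing import Optional, Sequence


def matmul_flops(m: int, k: int, n: int) -> int:
    # Multiply-accumulate count for an [m, k] @ [k, n] product.
    return m * k * n


def choose_order(shape_a: Sequence[int], shape_b: Sequence[int],
                 shape_c: Sequence[int]) -> Optional[str]:
    """Hypothetical guard: pick an association for A @ B @ C, or None to skip."""
    # Bail out instead of asserting when any operand is not 2D.
    if not (len(shape_a) == len(shape_b) == len(shape_c) == 2):
        return None
    (m, k1), (k2, n1), (n2, p) = shape_a, shape_b, shape_c
    if k1 != k2 or n1 != n2:
        return None  # inner dimensions do not match; leave the chain alone
    left = matmul_flops(m, k1, n1) + matmul_flops(m, n1, p)   # (A @ B) @ C
    right = matmul_flops(k1, n1, p) + matmul_flops(m, k1, p)  # A @ (B @ C)
    return "(A@B)@C" if left <= right else "A@(B@C)"


print(choose_order((2, 16, 64), (64, 64), (64, 64)))    # None (3D operand, skipped)
print(choose_order((8, 1024), (1024, 16), (16, 1024)))  # (A@B)@C
```

The second call is a LoRA-shaped chain `x @ (A @ B)` with a small rank. For that shape, computing `(x @ A) @ B` first is about 64 times cheaper.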
## Expected behavior
The pass should either handle mixed-dimension matmul chains (3D @ 2D) or
skip them gracefully.
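The "handle" option is at least mathematically sound when only the leftmost operand carries batch dimensions. Reassociating the chain gives the same result, and folding the batch dimensions into the row dimension reduces the chain to the 2D case. A minimal NumPy check, with `R` as an arbitrary illustrative inner dimension:

```python
import numpy as np

B, S, D, R = 2, 16, 64, 8
rng = np.random.default_rng(0)
a = rng.standard_normal((B, S, D))  # 3D, batched
w1 = rng.standard_normal((D, R))    # 2D
w2 = rng.standard_normal((R, D))    # 2D

# Reassociation holds when only the leftmost operand is batched.
left = (a @ w1) @ w2
right = a @ (w1 @ w2)
print(np.allclose(left, right))  # True

# Folding [B, S] into a single row dimension reduces to a 2D chain.
flat = ((a.reshape(B * S, D) @ w1) @ w2).reshape(B, S, D)
print(np.allclose(flat, left))  # True
```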
## Environment
- TVM version: 0.24.dev0 (commit 0b0afd8dd, 2026-04-24)