wuyii8941 opened a new issue, #19576:
URL: https://github.com/apache/tvm/issues/19576

   
   ## Description
   
   The `AdjustMatmulOrder` pass crashes with `Check failed: shape_c.size() == 2 (3 vs. 2)` when it encounters a matmul chain in which the intermediate result is 3D and the final weight is 2D.
   
   This is a common pattern in transformer models: `matmul(attn_output[B,S,D], W_o[D,D])`, where the attention output is 3D (batched) and the output projection weight is 2D.
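For reference, NumPy's `matmul` has the shape semantics this pattern relies on. The 2D weight broadcasts across the leading batch axis, so a 3D @ 2D product equals a single 2D matmul over the flattened `[B*S, D]` input. The sketch below assumes Relax `matmul` follows the same NumPy-style batch broadcasting, which the reproducer's struct info suggests. It only illustrates shapes and does not exercise TVM.

```python
import numpy as np

# Shapes from the report: batch B, sequence length S, model width D.
B, S, D = 2, 16, 64
rng = np.random.default_rng(0)
attn_output = rng.standard_normal((B, S, D)).astype(np.float32)
w_o = rng.standard_normal((D, D)).astype(np.float32)

# np.matmul broadcasts the 2D weight across the leading batch axis.
proj = np.matmul(attn_output, w_o)
print(proj.shape)  # (2, 16, 64)

# Equivalent 2D formulation: fold the batch and sequence axes into rows,
# multiply once, then restore the original leading axes.
proj_2d = (attn_output.reshape(B * S, D) @ w_o).reshape(B, S, D)
print(np.allclose(proj, proj_2d, rtol=1e-4, atol=1e-4))  # True
```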
   
   ## Reproducer
   
   ```python
   import numpy as np
   import tvm
   from tvm import relax
   import tvm.relax.op as R
   
   B, S, D = 2, 16, 64
   bb = relax.BlockBuilder()
   x = relax.Var('x', relax.TensorStructInfo((B, S, D), 'float32'))
   wq = relax.Var('wq', relax.TensorStructInfo((D, D), 'float32'))
   wk = relax.Var('wk', relax.TensorStructInfo((D, D), 'float32'))
   wv = relax.Var('wv', relax.TensorStructInfo((D, D), 'float32'))
   wo = relax.Var('wo', relax.TensorStructInfo((D, D), 'float32'))
   with bb.function('main', [x, wq, wk, wv, wo]):
       with bb.dataflow():
           q = bb.emit(R.matmul(x, wq))
           k = bb.emit(R.matmul(x, wk))
           v = bb.emit(R.matmul(x, wv))
           kt = bb.emit(R.permute_dims(k, [0, 2, 1]))
           scores = bb.emit(R.matmul(q, kt))
           scale = relax.const(1.0 / np.sqrt(D), 'float32')
           scores = bb.emit(R.multiply(scores, scale))
           attn = bb.emit(R.nn.softmax(scores, axis=-1))
           out = bb.emit(R.matmul(attn, v))         # 3D result
           proj = bb.emit_output(R.matmul(out, wo))  # 3D @ 2D → crash
       bb.emit_func_output(proj)
   mod = bb.finalize()
   
   # This crashes:
   pipeline = tvm.ir.transform.Sequential([
       relax.transform.AdjustMatmulOrder(),
       relax.transform.LegalizeOps()
   ])
   mod_l = pipeline(mod)  # Check failed: shape_c.size() == 2 (3 vs. 2)
   ```
   
   ## Error
   
   ```
   tvm.error.InternalError: Check failed: shape_c.size() == 2 (3 vs. 2) :
   ```
   
   ## Root cause
   
   The `AdjustMatmulOrder` pass assumes that all operands in a matmul chain are 2D (`shape.size() == 2`). The `(3 vs. 2)` in the message means the operand whose shape is held in `shape_c` has rank 3, while the pass requires rank 2. Here, `matmul(attn, v)` produces a 3D tensor `[B, S, D]` that is then multiplied by the 2D weight `[D, D]`, so the assertion fails.
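The mixed ranks are not a mathematical obstacle. Matrix multiplication stays associative when the leading batch axis is broadcast, so `(attn @ v) @ wo` and `attn @ (v @ wo)` produce the same `[B, S, D]` result. The NumPy check below uses the reproducer's shapes with random data, in float64 to keep rounding error negligible. It shows only that the reordering is valid and says nothing about which order TVM would, or should, pick.

```python
import numpy as np

# Shapes from the reproducer: attn is [B, S, S], v is [B, S, D], wo is [D, D].
B, S, D = 2, 16, 64
rng = np.random.default_rng(0)
attn = rng.random((B, S, S))
v = rng.standard_normal((B, S, D))
wo = rng.standard_normal((D, D))

left_first = (attn @ v) @ wo    # order written in the reproducer
right_first = attn @ (v @ wo)   # reassociated order

print(left_first.shape, right_first.shape)  # (2, 16, 64) (2, 16, 64)
print(np.allclose(left_first, right_first))  # True
```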
   
   ## Expected behavior
   
   The pass should either handle mixed-dimension matmul chains (3D @ 2D) or skip them gracefully.
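Both options can be illustrated in one place with a hypothetical, pure-Python sketch for the `(A @ B) @ C` form. It handles a batched `A` with 2D `B` and `C` by folding `A`'s leading axes into its row count. For any other rank combination, or for mismatched inner dimensions, it returns `None` (leave the expression unchanged) instead of asserting. The function name `plan_matmul_chain` and its cost model (multiply-accumulate counts) are assumptions for illustration only. This is not TVM's actual `AdjustMatmulOrder` logic, which lives in C++.

```python
from math import prod
from typing import Optional, Sequence


def plan_matmul_chain(
    shape_a: Sequence[int], shape_b: Sequence[int], shape_c: Sequence[int]
) -> Optional[str]:
    """Hypothetical planner for (A @ B) @ C.

    Returns "lhs_first" to keep (A @ B) @ C, "rhs_first" to rewrite to
    A @ (B @ C), or None to skip the rewrite instead of failing a check.
    """
    # Only B and C as plain 2D matrices are handled; A may carry leading
    # batch axes. Anything else (e.g. a batched B) is skipped, not asserted.
    if len(shape_a) < 2 or len(shape_b) != 2 or len(shape_c) != 2:
        return None
    rows = prod(shape_a[:-1])  # fold batch axes of A into its row count
    k = shape_a[-1]
    k_b, n = shape_b
    n_c, m = shape_c
    if k != k_b or n != n_c:
        return None  # inner dimensions do not line up
    # Multiply-accumulate counts for each association order.
    cost_lhs_first = rows * k * n + rows * n * m  # (A @ B), then @ C
    cost_rhs_first = k * n * m + rows * k * m     # (B @ C), then A @
    return "rhs_first" if cost_rhs_first < cost_lhs_first else "lhs_first"


# The reproducer's (attn @ v) @ wo has a 3D middle operand, so it is skipped.
print(plan_matmul_chain((2, 16, 16), (2, 16, 64), (64, 64)))   # None
# (x @ wv) @ wo from the reproducer: keeping the original order is cheaper.
print(plan_matmul_chain((2, 16, 64), (64, 64), (64, 64)))      # lhs_first
# A low-rank middle product makes the reassociated order cheaper.
print(plan_matmul_chain((2, 1000, 8), (8, 1000), (1000, 8)))   # rhs_first
```

Returning `None` mirrors the "skip gracefully" option, and the batch-folding branch mirrors the "handle" option for the 3D @ 2D case described in this report.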
   
   ## Environment
   
   - TVM version: 0.24.dev0 (commit 0b0afd8dd, 2026-04-24)
   

