Thrsu opened a new issue, #17357:
URL: https://github.com/apache/tvm/issues/17357

   I encountered an issue while running a Relax module with a specific transformation sequence. When `FuseTIR()` is applied only once, before `LambdaLift()`, the VM fails to find the PackedFunc `fused_relax_nn_attention_cutlass_gv`. However, when `FuseTIR()` is applied a second time, after `LambdaLift()` and before `AllocateWorkspace()`, the problem disappears.
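
   To make the failure easier to see, the surviving global functions can be listed after the failing pipeline runs. A minimal sketch, reusing the `Module` defined in the reproduction script below (`get_global_vars` is the standard `IRModule` accessor):

   ```python
   # Minimal inspection sketch (assumes the `Module` from the script below):
   # list the global functions left after the failing pipeline, to check
   # whether fused_relax_nn_attention_cutlass_gv is referenced but no
   # longer defined as a global function.
   failing = tvm.transform.Sequential(
       [
           relax.transform.FuseTIR(),
           relax.transform.LambdaLift(),
           relax.transform.AllocateWorkspace(),
       ]
   )(Module)
   print(sorted(gv.name_hint for gv in failing.get_global_vars()))
   ```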
   
   ### Expected behavior
   
   The script is expected to run successfully without errors.
   
   ### Actual behavior
   
   ```
   InternalError: Check failed: (func.defined()) is false: Error: Cannot find PackedFunc fused_relax_nn_attention_cutlass_gv in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in global Relax functions of the VM executable
   ```
   
   ### Steps to reproduce
   
   <details>
   
   <summary>The following script reproduces the issue:</summary>
   
   ```python
   import tvm
   from tvm import relax

   from tvm.script import ir as I
   from tvm.script import tir as T
   from tvm.script import relax as R

   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def attention(q_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), k_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), v_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), T_transpose: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           T_transpose_1 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
           T_reshape = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
           T_transpose_2 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
           T_reshape_1 = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
           T_batch_matmul_NT = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
           T_divide = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
           T_softmax_maxelem = T.alloc_buffer((T.int64(512), T.int64(8)), "float16")
           T_softmax_exp = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
           T_softmax_expsum = T.alloc_buffer((T.int64(512), T.int64(8)), "float16")
           T_softmax_norm = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
           T_transpose_3 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
           T_reshape_2 = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
           T_batch_matmul_NN = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
           T_reshape_3 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
           for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
               with T.block("T_transpose"):
                   v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                   T.reads(q_1[v_ax0, v_ax2, v_ax1, v_ax3])
                   T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
                   T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = q_1[v_ax0, v_ax2, v_ax1, v_ax3]
           for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
               with T.block("T_reshape"):
                   v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                   T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                   T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                   T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
           for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
               with T.block("T_transpose_1"):
                   v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                   T.reads(k_1[v_ax0, v_ax2, v_ax1, v_ax3])
                   T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
                   T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = k_1[v_ax0, v_ax2, v_ax1, v_ax3]
           for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
               with T.block("T_reshape_1"):
                   v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                   T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                   T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
                   T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
           for b, i, j, k in T.grid(T.int64(512), T.int64(8), T.int64(8), T.int64(8)):
               with T.block("T_batch_matmul_NT"):
                   v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                   T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k])
                   T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
                   T.block_attr({"layout_free_placeholders": [T_reshape_1]})
                   with T.init():
                       T_batch_matmul_NT[v_b, v_i, v_j] = T.float16(0)
                   T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
           for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
               with T.block("T_divide"):
                   v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                   T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2])
                   T.writes(T_divide[v_ax0, v_ax1, v_ax2])
                   T_divide[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] / T.sqrt(T.float16(8))
           for i0, i1, k in T.grid(T.int64(512), T.int64(8), T.int64(8)):
               with T.block("T_softmax_maxelem"):
                   v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                   T.reads(T_divide[v_i0, v_i1, v_k])
                   T.writes(T_softmax_maxelem[v_i0, v_i1])
                   with T.init():
                       T_softmax_maxelem[v_i0, v_i1] = T.float16(-65504)
                   T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], T_divide[v_i0, v_i1, v_k])
           for i0, i1, i2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
               with T.block("T_softmax_exp"):
                   v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                   T.reads(T_divide[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
                   T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
                   T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(T_divide[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
           for i0, i1, k in T.grid(T.int64(512), T.int64(8), T.int64(8)):
               with T.block("T_softmax_expsum"):
                   v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                   T.reads(T_softmax_exp[v_i0, v_i1, v_k])
                   T.writes(T_softmax_expsum[v_i0, v_i1])
                   with T.init():
                       T_softmax_expsum[v_i0, v_i1] = T.float16(0)
                   T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
           for i0, i1, i2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
               with T.block("T_softmax_norm"):
                   v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                   T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
                   T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
                   T.block_attr({"axis": 2})
                   T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
           for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
               with T.block("T_transpose_2"):
                   v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                   T.reads(v_1[v_ax0, v_ax2, v_ax1, v_ax3])
                   T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
                   T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = v_1[v_ax0, v_ax2, v_ax1, v_ax3]
           for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
               with T.block("T_reshape_2"):
                   v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                   T.reads(T_transpose_3[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                   T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2])
                   T_reshape_2[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
           for b, i, j, k in T.grid(T.int64(512), T.int64(8), T.int64(8), T.int64(8)):
               with T.block("T_batch_matmul_NN"):
                   v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                   T.reads(T_softmax_norm[v_b, v_i, v_k], T_reshape_2[v_b, v_k, v_j])
                   T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
                   T.block_attr({"layout_free_placeholders": [T_reshape_2]})
                   with T.init():
                       T_batch_matmul_NN[v_b, v_i, v_j] = T.float16(0)
                   T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_softmax_norm[v_b, v_i, v_k] * T_reshape_2[v_b, v_k, v_j]
           for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
               with T.block("T_reshape_3"):
                   v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                   T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(16) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(8) + v_ax1) % T.int64(512), (v_ax3 // T.int64(8) + v_ax2) % T.int64(8), v_ax3 % T.int64(8)])
                   T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
                   T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(16) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(8) + v_ax1) % T.int64(512), (v_ax3 // T.int64(8) + v_ax2) % T.int64(8), v_ax3 % T.int64(8)]
           for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(8), T.int64(16), T.int64(8)):
               with T.block("T_transpose_3"):
                   v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                   T.reads(T_reshape_3[v_ax0, v_ax2, v_ax1, v_ax3])
                   T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                   T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_3[v_ax0, v_ax2, v_ax1, v_ax3]

       @R.function
       def entry_b(q: R.Tensor((32, 8, 16, 8), dtype="float16"), k: R.Tensor((32, 8, 16, 8), dtype="float16"), v: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
           cls = Module
           with R.dataflow():
               lv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass(q, k, v)
               R.output(lv)
           return lv

       @R.function
       def fused_relax_nn_attention_cutlass(q: R.Tensor((32, 8, 16, 8), dtype="float16"), k: R.Tensor((32, 8, 16, 8), dtype="float16"), v: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
           R.func_attr({"Codegen": "cutlass", "WorkspaceSize": 65536})
           cls = Module

           @R.function
           def gv(q_1: R.Tensor((32, 8, 16, 8), dtype="float16"), k_1: R.Tensor((32, 8, 16, 8), dtype="float16"), v_1: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
               R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, "WorkspaceSize": 65536})
               with R.dataflow():
                   gv_2 = R.call_tir(cls.attention, (q_1, k_1, v_1), out_sinfo=R.Tensor((32, 8, 16, 8), dtype="float16"))
                   R.output(gv_2)
               return gv_2

           gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v)
           return gv1
   
   mod = Module

   # Crashes: FuseTIR runs only once, before LambdaLift.
   mod = tvm.transform.Sequential(
       [
           relax.transform.FuseTIR(),
           relax.transform.LambdaLift(),
           relax.transform.AllocateWorkspace(),
       ]
   )(mod)

   # Passes: FuseTIR runs a second time after LambdaLift.
   # mod = tvm.transform.Sequential(
   #     [
   #         relax.transform.FuseTIR(),
   #         relax.transform.LambdaLift(),
   #         relax.transform.FuseTIR(),
   #         relax.transform.AllocateWorkspace(),
   #     ]
   # )(mod)

   with tvm.transform.PassContext(opt_level=4):
       ex = relax.build(mod, target='llvm')
       vm = relax.VirtualMachine(ex, tvm.cpu())
   ```
   
   </details>
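
   For comparison, the same check on the workaround pipeline (the commented-out sequence in the script, with a second `FuseTIR()` after `LambdaLift()`), which I would expect to build successfully:

   ```python
   # Sketch of the workaround pipeline observed above: applying FuseTIR a
   # second time, after LambdaLift and before AllocateWorkspace, avoids
   # the missing-PackedFunc error.
   passing = tvm.transform.Sequential(
       [
           relax.transform.FuseTIR(),
           relax.transform.LambdaLift(),
           relax.transform.FuseTIR(),  # second application
           relax.transform.AllocateWorkspace(),
       ]
   )(Module)
   with tvm.transform.PassContext(opt_level=4):
       ex = relax.build(passing, target="llvm")
       vm = relax.VirtualMachine(ex, tvm.cpu())
   print(sorted(gv.name_hint for gv in passing.get_global_vars()))
   ```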
   
   Any guidance on whether this is a bug or a known order dependency would be 
greatly appreciated. 
   @Lunderberg 

