LeiWang1999 opened a new pull request, #14449: URL: https://github.com/apache/tvm/pull/14449
In the current CUDA codegen, `PrintMMAAssembly` instantiates the final asm code by replacing the placeholders A, B, ... in the template. However, this matching is incorrect in some cases, because a bare single letter can also match text that is not meant to be a placeholder, which corrupts the result.

https://github.com/apache/tvm/blob/49e6695586d07c33c84097d2b0f58c79c2abd51e/src/target/source/ptx.cc#L564-L571

For example, here is a case I ran into:

```
{
  __asm__ __volatile__(
    "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"
    "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
    : "=r"(((unsigned *)(C_warp + ((ii_2 * 16) + (jj_2 * 8))))[0]),
      "=r"(((unsigned *)(C_warp + ((ii_2 * 16) + (jj_2 * 8))))[1])
    : "r"(((unsigned *)(AC_warp + ((ii_2 * 16) + (jj_2 * 8))_shared_warp + (ii_2 * 8)))[0]),
      "r"(((unsigned *)(AC_warp + ((ii_2 * 16) + (jj_2 * 8))_shared_warp + (ii_2 * 8)))[1]),
      "r"(((unsigned *)(AC_warp + ((ii_2 * 16) + (jj_2 * 8))_shared_warp + (ii_2 * 8)))[2]),
      "r"(((unsigned *)(AC_warp + ((ii_2 * 16) + (jj_2 * 8))_shared_warp + (ii_2 * 8)))[3]),
      "r"(((unsigned *)(BC_warp + ((ii_2 * 16) + (jj_2 * 8))_shared_warp + (jj_2 * 8)))[0]),
      "r"(((unsigned *)(BC_warp + ((ii_2 * 16) + (jj_2 * 8))_shared_warp + (jj_2 * 8)))[1]),
      "r"(((unsigned *)(C_warp + ((ii_2 * 16) + (jj_2 * 8))))[0]),
      "r"(((unsigned *)(C_warp + ((ii_2 * 16) + (jj_2 * 8))))[1]));
}
```

The expected source operands and offsets are:

```
{
  __asm__ __volatile__(
    "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"
    "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
    : "=r"(((unsigned *)(C_warp + ((ii_2 * 64) + (jj_2 * 8))))[0]),
      "=r"(((unsigned *)(C_warp + ((ii_2 * 64) + (jj_2 * 8))))[1])
    : "r"(((unsigned *)(A_shared_warp + (ii_2 * 8)))[0]),
      "r"(((unsigned *)(A_shared_warp + (ii_2 * 8)))[1]),
      "r"(((unsigned *)(A_shared_warp + (ii_2 * 8)))[2]),
      "r"(((unsigned *)(A_shared_warp + (ii_2 * 8)))[3]),
      "r"(((unsigned *)(B_shared_warp + (jj_2 * 8)))[0]),
      "r"(((unsigned *)(B_shared_warp + (jj_2 * 8)))[1]),
      "r"(((unsigned *)(C_warp + ((ii_2 * 64) + (jj_2 * 8))))[0]),
      "r"(((unsigned *)(C_warp + ((ii_2 * 64) + (jj_2 * 8))))[1]));
}
```

A simple fix is to change the patterns from "A" to "{A}", "B" to "{B}", and so on.
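To illustrate why bare single-letter patterns are fragile, here is a minimal standalone sketch. It is not TVM's `Replacer`: the helpers `ReplaceAll` and `Rewrite`, the toy template, and the operand names `A_BUF`, `B_BUF`, `C_BUF` are all hypothetical, chosen only so that one substitution contains a letter that a later rule matches. It assumes the rules are applied one after another over the whole string.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper (not TVM's Replacer): replace every occurrence of
// `pattern` in `text`, skipping over the text that was just inserted.
std::string ReplaceAll(std::string text, const std::string& pattern,
                       const std::string& replacement) {
  if (pattern.empty()) return text;
  std::size_t pos = 0;
  while ((pos = text.find(pattern, pos)) != std::string::npos) {
    text.replace(pos, pattern.size(), replacement);
    pos += replacement.size();
  }
  return text;
}

using Rules = std::vector<std::pair<std::string, std::string>>;

// Assumption: rules are applied sequentially, each over the full string,
// so a later rule also sees text inserted by an earlier one.
std::string Rewrite(std::string text, const Rules& rules) {
  for (const auto& rule : rules) {
    text = ReplaceAll(std::move(text), rule.first, rule.second);
  }
  return text;
}

// Bare patterns: the "B" rule also matches the 'B' inside "A_BUF",
// which the "A" rule inserted one step earlier.
std::string RewriteBare() {
  return Rewrite("(A), (B), (C)",
                 {{"A", "A_BUF"}, {"B", "B_BUF"}, {"C", "C_BUF"}});
}

// Braced patterns: "{B}" cannot occur inside an inserted operand name,
// so each placeholder is replaced exactly once.
std::string RewriteBraced() {
  return Rewrite("({A}), ({B}), ({C})",
                 {{"{A}", "A_BUF"}, {"{B}", "B_BUF"}, {"{C}", "C_BUF"}});
}
```

With these toy inputs, `RewriteBare()` yields `(A_B_BUFUF), (B_BUF), (C_BUF)`, while `RewriteBraced()` yields the intended `(A_BUF), (B_BUF), (C_BUF)`. The braced form only stays safe as long as no operand expression itself contains a `{X}` sequence.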
