LeiWang1999 opened a new pull request, #14329:
URL: https://github.com/apache/tvm/pull/14329

   This pull request adds support for the L2 prefetch option of the `cp.async` 
instruction, which is available in CUDA 11.4 and later; the implementation 
follows the approach used in Cutlass. It also adds support for asynchronous 
copies of vectorized `if_then_else` assignments, and fixes some bugs.
   
   For example, the original async copy lowering cannot handle a vectorized 
`if_then_else`. Given the following TIR:
   
   ```python
   for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"):
       for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"):
           for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
               with T.block("data_im2col_reindex_shared.dyn_o"):
                   v0 = T.axis.spatial(512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8)
                   v1_o = T.axis.spatial(1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8)
                   T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8])
                   T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8])
                   for ax1_1 in T.vectorized(8):
                       with T.block("data_im2col_reindex_shared.dyn"):
                           v1_i = T.axis.spatial(8, ax1_1)
                           T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i])
                           T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i])
                           T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]})
                           data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0))
   ```
   
   with this PR, the generated code becomes:
   
   ```c++
     {
       unsigned int addr;
       __asm__ __volatile__(
         "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
         : "=r"(addr)
         : "l"((void *)(buf_dyn_shmem + (((((ax0_ax1_0_fused_0 * 2304) + (((int)threadIdx.z) * 1152)) + (((int)threadIdx.y) * 576)) + ((((int)threadIdx.x) >> 3) * 144)) + ((((int)threadIdx.x) & 7) * 16))))
       );
       int pred_guard = (int)((((1 <= ((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0)) && (((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0) < 17)) && (1 <= ((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)))) && (((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)) < 17));
       __asm__ __volatile__(
           "{  .reg .pred p;"
           "  setp.ne.b32 p, %0, 0;"
         #if TVM_ENABLE_L2_PREFETCH
           " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
         #else
           " @p cp.async.ca.shared.global [%1], [%2], %3;"
         #endif
         "  @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
           :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((((((((int)blockIdx.y) * 81920) + ((k_0_0 / 60) * 20480)) + (ax0_ax1_0_fused_0 * 20480)) + (((int)threadIdx.z) * 10240)) + (((int)threadIdx.y) * 5120)) + ((((int)threadIdx.x) >> 3) * 1280)) + ((k_0_0 % 60) * 64)) + ((((int)threadIdx.x) & 7) * 8)) - 21760))), "n"(16), "r"(0), "r"(0), "r"(0), "r"(0)
       );
     }
   ```
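   The switch between the plain opcode and the `.L2::128B` variant can be sketched as a small codegen-time helper. This is an illustration only: `CpAsyncOpcode` is a hypothetical name, not TVM's actual codegen API, and the real decision is made by the `TVM_ENABLE_L2_PREFETCH` macro shown above.

   ```cpp
   #include <cassert>
   #include <string>

   // Hypothetical sketch: build the cp.async opcode string, choosing
   // .ca (cache at all levels) vs .cg (cache at global/L2 only), and
   // optionally appending the L2 prefetch hint available on CUDA 11.4+.
   std::string CpAsyncOpcode(bool l2_prefetch, bool bypass_l1) {
     std::string op = "cp.async.";
     op += bypass_l1 ? "cg" : "ca";
     op += ".shared.global";
     if (l2_prefetch) op += ".L2::128B";  // prefetch a 128-byte L2 sector
     return op;
   }
   ```

   The prefetch hint is purely a performance hint, which is why it can be toggled by a compile-time macro without changing the copy's semantics.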
   
   The original `InjectPTXAsyncCopy` pass only supports serial, 4-byte-aligned 
`if_then_else` assignments, and the current code also has a bug:
   
   ```c++
     std::string predicated_asm_code = R"(
     {
       unsigned int addr;
       __asm__ __volatile__(
         "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
         : "=r"(addr)
         : "l"((void *)({smem_addr}))
       );
       int src_bytes = {pred_guard} ? {bytes} : 0;
       __asm__ __volatile__(
         "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
          :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
       );
     }
   )";
   ```
   
   If the condition is false, the code above does nothing. The correct 
behavior, however, is to store zero to the shared memory address; otherwise 
its contents are undefined and the result is incorrect. We fix this with a 
predicated shared-memory store in the inline asm:
   
   ```c++
       int pred_guard = (int){pred_guard};
       __asm__ __volatile__(
           "{  .reg .pred p;"
           "  setp.ne.b32 p, %0, 0;"
         #if TVM_ENABLE_L2_PREFETCH
           " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
         #else
           " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
         #endif
         "  @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
           :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(0), "r"(0), "r"(0), "r"(0)
       );
   ```
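   The copy-or-zero semantics of the fixed asm can be modeled on the host. This is a sketch for illustration, not the device code itself: `PredicatedCopy16` is a hypothetical helper that mirrors the predicated `cp.async` (guard true) and the `st.shared.v4.u32` zero-fill fallback (guard false) for one 16-byte transfer.

   ```cpp
   #include <cstdint>
   #include <cstring>
   #include <cassert>

   // Host-side model of the fixed behavior: when the guard holds, 16 bytes
   // are copied from "global" to "shared" memory; when it does not, the same
   // 16 bytes are zero-filled rather than left untouched, so the destination
   // is never left with unpredictable contents.
   void PredicatedCopy16(void* shared_dst, const void* global_src, bool pred) {
     if (pred) {
       std::memcpy(shared_dst, global_src, 16);  // models @p cp.async ...
     } else {
       std::memset(shared_dst, 0, 16);  // models @!p st.shared.v4.u32 {0,0,0,0}
     }
   }
   ```

   The original template had no `else` branch, which is exactly the bug: a false guard left the destination bytes unwritten.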
   

