csullivan opened a new pull request, #17082:
URL: https://github.com/apache/tvm/pull/17082
Change to use 16x32 (spatial x reduction) thread extents regardless of
workload size. This works around a lowering bug that I haven't fully tracked
down yet.
Currently, when the spatial dimension is larger than the reduction dimension,
the schedule uses a 4x64 thread layout. This implies two warps along the
reduction dimension (blockDim.x = 64). An illegal CUDA instruction is
encountered in the second warp during the __shfl_down_sync for the remainder
portion of the computation (a result of the rfactor, I believe). AFAICT the
mask calculation used for this remainder shuffle is incorrect and causes the
error; specifically, it occurs on the first thread of the second warp (two
warps along x, since blockDim.x = 64).
This is the relevant CUDA code causing the error:
```
  if (((int)threadIdx.x) < 2) {
    red_buf0[0] = red_buf_staging[((((int)threadIdx.y) * 2) + ((int)threadIdx.x))];
  }
  mask[0] = (__activemask() & ((uint)(3 << (((int)threadIdx.y) * 2))));  // <<< likely the problem
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  if (((int)threadIdx.x) == 0) {
    ((volatile half*)red_result)[((int)threadIdx.y)] = red_buf0[0];
  }
```
The corresponding SASS, where the illegal instruction occurs:
```
0x00007d9e97b92490 <+1936>: WARPSYNC.ALL
0x00007d9e97b924a0 <+1952>: BAR.SYNC.DEFER_BLOCKING 0x0
0x00007d9e97b924b0 <+1968>: @!P1 VIADD R13, R5, 0x8
0x00007d9e97b924c0 <+1984>: @!P1 LEA R7, R17, R14, 0x1
0x00007d9e97b924d0 <+2000>: @!P1 PRMT R6, R2, 0x654, R13
0x00007d9e97b924e0 <+2016>: @!P1 LEA R7, R7, R6, 0x1
0x00007d9e97b924f0 <+2032>: @!P1 LDS.U16 R16, [R7]
0x00007d9e97b92500 <+2048>: IMAD.MOV.U32 R6, RZ, RZ, 0x3
0x00007d9e97b92510 <+2064>: SHF.L.U32 R17, R17, 0x1, RZ
0x00007d9e97b92520 <+2080>: VOTEU.ANY UR4, UPT, PT
0x00007d9e97b92530 <+2096>: SHF.L.U32 R3, R6, R17, RZ
0x00007d9e97b92540 <+2112>: LOP3.LUT R3, R3, UR4, RZ, 0xc0, !PT
0x00007d9e97b92550 <+2128>: ISETP.NE.AND P0, PT, R14, RZ, PT
0x00007d9e97b92560 <+2144>: PRMT R2, R2, 0x654, R5
0x00007d9e97b92570 <+2160>: PRMT R4, R16, 0x5410, R16
*> 0x00007d9e97b92580 <+2176>: WARPSYNC R3
=> 0x00007d9e97b92590 <+2192>: SHFL.DOWN PT, R3, R4, 0x1, 0x1f
0x00007d9e97b925a0 <+2208>: IMAD.IADD R17, R17, 0x1, R2
0x00007d9e97b925b0 <+2224>: HADD2 R16, R16.H0_H0, R3.H0_H0
0x00007d9e97b925c0 <+2240>: @!P0 STS.U16 [R17], R16
0x00007d9e97b925d0 <+2256>: WARPSYNC.ALL
0x00007d9e97b925e0 <+2272>: BAR.SYNC.DEFER_BLOCKING 0x0
```
Changing the thread extents to 16x32 (one warp along the reduction
dimension) works around the issue. It also improves performance for the tested
shapes by ~10%.
```
Utilizing (8, 2048, 4096) to avoid the error,

# 4x64
Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
--------  ---------------  ---------  --------  --------  --------  --------  -----------  --------------------------
    81.5           612214        101    6061.5    6048.0      5920      7872        188.5  moe_dequantize_gemv_kernel

# 16x32
Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
--------  ---------------  ---------  --------  --------  --------  --------  -----------  --------------------------
    79.9           555901        101    5504.0    5472.0      5439      6880        142.7  moe_dequantize_gemv_kernel
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]