jcf94 opened a new pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924
This PR is part of #5883. It fixes a `rewrite_simplify` error that occurs in some cases when doing vectorized cooperative fetching. The buggy generated code looks like this:

```
A.shared[ramp(((ax0.ax1.fused.outer.outer*256) + (threadIdx.x_1*4)), 1, 4)] = (float32x4*)A_2[(((broadcast(((floordiv(blockIdx.x, 4)*32768) + (ax0.ax1.fused.outer.outer*2048)), 4) + (floordiv(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4))*broadcast(512, 4))) + broadcast((k.outer.outer*64), 4)) + floormod(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4)))])
```

This eventually lowers to incorrect CUDA C instructions. The expression should instead be simplified so that the index is a single RampNode:

```
A.shared[ramp(((ax0.ax1.fused.outer.outer*256) + (threadIdx.x_1*4)), 1, 4)] = (float32x4*)A_2[ramp((((((floordiv(blockIdx.x, 4)*32768) + (ax0.ax1.fused.outer.outer*2048)) + (floordiv(threadIdx.x_1, 16)*512)) + (k.outer.outer*64)) + (floormod(threadIdx.x_1, 16)*4)), 1, 4)])
```

The main problems inside this expression are:

```
threadIdx.x_1 = [0, 64]
floordiv(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4)) * broadcast(512, 4)
floormod(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4))
```

which should be simplified to:

```
threadIdx.x_1 = [0, 64]
broadcast((floordiv(threadIdx.x_1, 16)*512), 4)
ramp(floormod(threadIdx.x_1, 16)*4, 1, 4)
```
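To make the expected rewrite concrete, here is a minimal Python sketch (not TVM's implementation) that models `ramp`/`broadcast` as lists of lane values and brute-force checks, for the constants in this PR, that the two rewrites and the full `A_2` index agree lane by lane. The helper names and the side condition in `rule_applies` (stride 1, divisor a multiple of the ramp's coefficient, lane count at most that coefficient) are assumptions for illustration; the condition actually used by `rewrite_simplify` may differ.

```python
# Minimal sketch, NOT TVM's implementation. Vector expressions are modelled as
# plain lists of lane values; Python's // and % floor toward -inf, which
# matches TIR floordiv/floormod. All helper names here are hypothetical.

def ramp(base, stride, lanes):
    """Lane values of ramp(base, stride, lanes)."""
    return [base + i * stride for i in range(lanes)]


def broadcast(value, lanes):
    """Lane values of broadcast(value, lanes)."""
    return [value] * lanes


def rule_applies(c1, lanes, c2):
    """Assumed side condition for rewriting floordiv/floormod of
    ramp(c1*x, 1, lanes) by broadcast(c2, lanes): c2 is a multiple of c1 and
    the ramp spans at most c1 values, so no lane crosses a multiple of c2."""
    return c1 > 0 and c2 % c1 == 0 and lanes <= c1


def floordiv_ramp(x, c1, lanes, c2):
    # floordiv(ramp(c1*x, 1, lanes), broadcast(c2, lanes))
    #   -> broadcast(floordiv(x, c2/c1), lanes)
    return broadcast(x // (c2 // c1), lanes)


def floormod_ramp(x, c1, lanes, c2):
    # floormod(ramp(c1*x, 1, lanes), broadcast(c2, lanes))
    #   -> ramp(floormod(x, c2/c1)*c1, 1, lanes)
    return ramp((x % (c2 // c1)) * c1, 1, lanes)


def original_index(bx, o, k, tx):
    # Lane-wise value of the A_2 index in the buggy output
    # (bx = blockIdx.x, o = ax0.ax1.fused.outer.outer, k = k.outer.outer,
    #  tx = threadIdx.x_1).
    r = ramp(tx * 4, 1, 4)
    head = broadcast((bx // 4) * 32768 + o * 2048, 4)
    div_term = [(v // 64) * 512 for v in r]
    k_term = broadcast(k * 64, 4)
    mod_term = [v % 64 for v in r]
    return [a + b + c + d for a, b, c, d in zip(head, div_term, k_term, mod_term)]


def simplified_index(bx, o, k, tx):
    # The single ramp the index should simplify to.
    base = ((bx // 4) * 32768 + o * 2048 + (tx // 16) * 512
            + k * 64 + (tx % 16) * 4)
    return ramp(base, 1, 4)


if __name__ == "__main__":
    C1, LANES, C2 = 4, 4, 64  # constants from the expressions above
    assert rule_applies(C1, LANES, C2)
    for tx in range(64):
        lanes_in = ramp(tx * C1, 1, LANES)
        assert [v // C2 for v in lanes_in] == floordiv_ramp(tx, C1, LANES, C2)
        assert [v % C2 for v in lanes_in] == floormod_ramp(tx, C1, LANES, C2)
        for bx in range(8):
            for o in range(2):
                for k in range(4):
                    assert original_index(bx, o, k, tx) == simplified_index(bx, o, k, tx)
    # With 8 lanes the ramp can straddle a multiple of 64 (60..67), so the
    # quotient is not uniform across lanes and the rule must not fire.
    assert not rule_applies(C1, 8, C2)
    assert len({v // C2 for v in ramp(15 * C1, 1, 8)}) == 2
    print("all checks passed")
```

The 8-lane case at the end shows why a side condition is needed at all: without it, the quotient can differ between lanes, so the result is no longer a `broadcast` and the remainder is no longer a stride-1 `ramp`.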
