yzh119 opened a new pull request #10207:
URL: https://github.com/apache/tvm/pull/10207


   Previously, the `LowerThreadAllReduce` pass would only emit code that uses 
`shfl_down` when the reduce extent equals the warp size; when the reduce extent is 
less than the warp size, the codegen falls back to emitting code that uses shared 
memory, which is not efficient. Considering that CUDA supports sub-warp reduction by 
specifying the mask, we can still use the shuffle-down approach for reduction by 
changing the mask.
   
   Example code:
   ```python
   import tvm
   import numpy as np
   from tvm.script import tir as T
   
   
   @T.prim_func
   def reduce(a: T.handle, b: T.handle) -> None:
       A = T.match_buffer(a, [1024, 11])
       B = T.match_buffer(b, [1024])
   
       for i, j in T.grid(1024, 11):
           with T.block("reduce"):
               vi, vj = T.axis.remap("SR", [i, j])
               with T.init():
                   B[vi] = 0.
               B[vi] = B[vi] + A[vi, vj]
   
   sch = tvm.tir.Schedule(reduce)
   blk = sch.get_block("reduce")
   i, j = sch.get_loops(blk)
   sch.bind(i, "blockIdx.x")
   sch.bind(j, "threadIdx.x")
   f = tvm.build(sch.mod["main"], target="cuda")
   print(f.imported_modules[0].get_source())
   ```
   
   
   Emitted code before this PR:
   ```cuda
   extern "C" __global__ void __launch_bounds__(11) 
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
     __shared__ float red_buf0[11];
     __syncthreads();
     ((volatile float*)red_buf0)[(((int)threadIdx.x))] = A[(((((int)blockIdx.x) 
* 11) + ((int)threadIdx.x)))];
     __syncthreads();
     if (((int)threadIdx.x) < 3) {
       ((volatile float*)red_buf0)[(((int)threadIdx.x))] = (((volatile 
float*)red_buf0)[(((int)threadIdx.x))] + ((volatile 
float*)red_buf0)[((((int)threadIdx.x) + 8))]);
     }
     __syncthreads();
     if (((int)threadIdx.x) < 4) {
       float w_4_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + 
((volatile float*)red_buf0)[((((int)threadIdx.x) + 4))]);
       ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_4_0;
       float w_2_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + 
((volatile float*)red_buf0)[((((int)threadIdx.x) + 2))]);
       ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_2_0;
       float w_1_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + 
((volatile float*)red_buf0)[((((int)threadIdx.x) + 1))]);
       ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_1_0;
     }
     __syncthreads();
     B[(((int)blockIdx.x))] = ((volatile float*)red_buf0)[(0)];
   }
   ```
   
   Emitted code after this PR:
   ```cuda
   extern "C" __global__ void __launch_bounds__(11) 
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
     float red_buf0[1];
     uint mask[1];
     float t0[1];
     red_buf0[(0)] = A[(((((int)blockIdx.x) * 11) + ((int)threadIdx.x)))];
     mask[(0)] = (__activemask() & (uint)2047);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 8, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], 0, 32);
     B[(((int)blockIdx.x))] = red_buf0[(0)];
   }
   ```
   
   # Future work
   CUDA 11 supports [cooperative group 
reduction](https://developer.nvidia.com/blog/cuda-11-features-revealed/) which 
we can directly use.
   
   cc @vinx13 @junrushao1994 
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to