yzh119 opened a new issue #10210:
URL: https://github.com/apache/tvm/issues/10210


   # Behavior
   I found this bug while writing unit tests for #10207; the following script 
reproduces it:
   ```python
   import tvm
   import tvm.testing
   import numpy as np
   from tvm.script import tir as T
   
   
   @T.prim_func
   def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) 
-> None:
       A = T.match_buffer(a, [1, d1, d2, d3])
       B = T.match_buffer(b, [1, d1, d2])
   
       for i, j, k, l  in T.grid(1, d1, d2, d3):
           with T.block("reduce"):
               vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
               with T.init():
                   B[vi, vj, vk] = 0.
               B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
   
   _, _, _d1, _d2, _d3 = reduce.params
   mod = reduce.specialize(
       {_d1: 1, _d2: 16, _d3: 26}
   )
   sch = tvm.tir.Schedule(mod)
   blk = sch.get_block("reduce")
   i, j, k, l = sch.get_loops(blk)
   sch.bind(i, "blockIdx.x")
   sch.bind(j, "threadIdx.z")
   sch.bind(k, "threadIdx.y")
   sch.bind(l, "threadIdx.x")
   f = tvm.build(sch.mod["main"], target="cuda")
   print(f.imported_modules[0].get_source())
   
   # prepare input and output array
   a_np = np.random.rand(1, d1, d2, d3).astype("float32")
   b_np = a_np.sum(axis=-1).astype("float32")
   a = tvm.nd.array(a_np, tvm.cuda(0))
   b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
   
   # launch kernel
   f(a, b)
   tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
   ```
   
   By looking into the generated CUDA code:
   ```cuda
   #ifdef _WIN32
     using uint = unsigned int;
     using uchar = unsigned char;
     using ushort = unsigned short;
     using int64_t = long long;
     using uint64_t = unsigned long long;
   #else
     #define uint unsigned int
     #define uchar unsigned char
     #define ushort unsigned short
     #define int64_t long long
     #define uint64_t unsigned long long
   #endif
   extern "C" __global__ void __launch_bounds__(416) 
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
     __shared__ float red_buf0[416];
     __syncthreads();
     ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = A[(((((int)threadIdx.y) * 26) + ((int)threadIdx.x)))];
     __syncthreads();
     if (((int)threadIdx.x) < 10) {
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)))] + ((volatile 
float*)red_buf0)[((((((int)threadIdx.y) * 26) + ((int)threadIdx.x)) + 16))]);
     }
     __syncthreads();
     if (((int)threadIdx.x) < 8) {
       float w_8_0 = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] + ((volatile float*)red_buf0)[((((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)) + 8))]);
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = w_8_0;
       float w_4_0 = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] + ((volatile float*)red_buf0)[((((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)) + 4))]);
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = w_4_0;
       float w_2_0 = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] + ((volatile float*)red_buf0)[((((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)) + 2))]);
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = w_2_0;
       float w_1_0 = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] + ((volatile float*)red_buf0)[((((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)) + 1))]);
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = w_1_0;
     }
     __syncthreads();
     B[(((int)threadIdx.y))] = ((volatile float*)red_buf0)[((((int)threadIdx.y) 
* 26))];
   }
   ```
   
   we can see that the bug occurs because no synchronization barrier is 
inserted inside `if (((int)threadIdx.x) < 8) {`: the compiler mistakenly 
assumes the reduction inside that region is performed within a single warp. 
However, this does not hold when `threadIdx.y`/`threadIdx.z` are also bound and 
`blockDim.x` is not a multiple of the warp size.
   
   This can be easily fixed by checking whether the reduction is warp-aligned; 
the fix is included in #10207.
   
   Generated code after the fix:
   ```cuda
   #ifdef _WIN32
     using uint = unsigned int;
     using uchar = unsigned char;
     using ushort = unsigned short;
     using int64_t = long long;
     using uint64_t = unsigned long long;
   #else
     #define uint unsigned int
     #define uchar unsigned char
     #define ushort unsigned short
     #define int64_t long long
     #define uint64_t unsigned long long
   #endif
   extern "C" __global__ void __launch_bounds__(416) 
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
     __shared__ float red_buf0[416];
     __syncthreads();
     ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = A[(((((int)threadIdx.y) * 26) + ((int)threadIdx.x)))];
     __syncthreads();
     if (((int)threadIdx.x) < 10) {
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)))] + ((volatile 
float*)red_buf0)[((((((int)threadIdx.y) * 26) + ((int)threadIdx.x)) + 16))]);
     }
     __syncthreads();
     if (((int)threadIdx.x) < 8) {
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)))] + ((volatile 
float*)red_buf0)[((((((int)threadIdx.y) * 26) + ((int)threadIdx.x)) + 8))]);
     }
     __syncthreads();
     if (((int)threadIdx.x) < 4) {
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)))] + ((volatile 
float*)red_buf0)[((((((int)threadIdx.y) * 26) + ((int)threadIdx.x)) + 4))]);
     }
     __syncthreads();
     if (((int)threadIdx.x) < 2) {
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)))] + ((volatile 
float*)red_buf0)[((((((int)threadIdx.y) * 26) + ((int)threadIdx.x)) + 2))]);
     }
     __syncthreads();
     if (((int)threadIdx.x) < 1) {
       ((volatile float*)red_buf0)[(((((int)threadIdx.y) * 26) + 
((int)threadIdx.x)))] = (((volatile float*)red_buf0)[(((((int)threadIdx.y) * 
26) + ((int)threadIdx.x)))] + ((volatile 
float*)red_buf0)[((((((int)threadIdx.y) * 26) + ((int)threadIdx.x)) + 1))]);
     }
     __syncthreads();
     B[(((int)threadIdx.y))] = ((volatile float*)red_buf0)[((((int)threadIdx.y) 
* 26))];
   }
   ```
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to