ZQPei opened a new pull request #8566:
URL: https://github.com/apache/tvm/pull/8566


   This commit fixes an illegal memory access bug in the CUDA reduction ops.
   
   ## Bug description
   Last week I ran a TVM module under `cuda-memcheck` to make sure its memory accesses were safe. However, the module failed the memory check, which told me something was wrong with the generated CUDA kernels, and I eventually tracked down the cause.
   
   To reproduce this error, run the following command on a CUDA-enabled machine:
   ```bash
   cuda-memcheck --report-api-errors no python3 ${TVM_HOME}/tests/python/topi/python/test_topi_reduce.py
   ```
   The terminal then shows an error report like this:
   ```
   ...
   Running on target: llvm -device=arm_cpu
   Running on target: cuda
   Running on target: llvm
   Running on target: llvm -device=arm_cpu
   Running on target: cuda
   ========= Invalid __global__ write of size 1
   =========     at 0x000003e0 in all_kernel0
   =========     by thread (0,23,0) in block (0,0,0)
   =========     Address 0x7fb84f000217 is out of bounds
   =========     Device Frame:all_kernel0 (all_kernel0 : 0x3e0)
   =========     Saved host backtrace up to driver entry point at kernel launch time
   =========     Host Frame:/usr/lib/x86_64-linux-gnu/libcuda.so.1 (cuLaunchKernel + 0x34e) [0x2e46de]
   =========     Host Frame:/media/4T/workspace/pzq/Workspace/tvm/build/libtvm.so (_ZNK3tvm7runtime15CUDAWrappedFuncclENS0_7TVMArgsEPNS0_11TVMRetValueEPPv + 0x181) [0x1dc4f81]
   ...
   ```
   This means that a reduction CUDA kernel performs an illegal, out-of-bounds memory write.
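The useful part of such a report is which kernel failed and which thread did the bad access. As a minimal sketch, assuming only the `=========`-prefixed line format visible in the sample report (not a documented output spec), a hypothetical helper `parse_memcheck` can pull those fields out:

```python
import re

# Patterns for cuda-memcheck report lines, assumed from the sample output above.
ACCESS_RE = re.compile(r"Invalid __global__ (read|write) of size (\d+)")
KERNEL_RE = re.compile(r"at 0x[0-9a-fA-F]+ in (\w+)")
THREAD_RE = re.compile(
    r"by thread \((\d+),(\d+),(\d+)\) in block \((\d+),(\d+),(\d+)\)")


def parse_memcheck(report):
    """Return one dict per invalid global-memory access found in `report`."""
    errors = []
    for line in report.splitlines():
        m = ACCESS_RE.search(line)
        if m:
            # Each "Invalid __global__ ..." line starts a new error record.
            errors.append({"kind": m.group(1), "size": int(m.group(2))})
            continue
        if not errors:
            continue  # skip lines before the first error, e.g. "Running on target"
        m = KERNEL_RE.search(line)
        if m:
            errors[-1]["kernel"] = m.group(1)
            continue
        m = THREAD_RE.search(line)
        if m:
            ids = tuple(int(g) for g in m.groups())
            errors[-1]["thread"], errors[-1]["block"] = ids[:3], ids[3:]
    return errors
```

For the sample report above, this yields a single write of size 1 in `all_kernel0` by thread `(0, 23, 0)` in block `(0, 0, 0)`, i.e. `threadIdx.x == 0` and `threadIdx.y == 23`.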
   
   ## Bug analysis
   To track down the error, I wrote the following small Python script, which builds and runs a `sum` op and saves the generated CUDA kernel source at the same time:
   
   `test_reduce_sum.py`
   ```python
   import numpy as np
   import tvm
   from tvm import te
   from tvm import topi
   import tvm.testing
   import tvm.topi.testing
   
   def test_reduce_sum(in_shape, axis, keepdims, type="sum", dtype="int32"):
       # Sum expr
       A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
       out_dtype = dtype
       B = topi.sum(A, axis=axis, keepdims=keepdims)
   
       device = "cuda"
       dev = tvm.cuda(0)
   
       with tvm.target.Target(device):
           s = topi.testing.get_reduce_schedule(device)(B)
   
       func = tvm.build(s, [A, B], device, name=type)
   
       # Data
       in_npy = np.random.randint(0, 256, size=in_shape).astype(dtype)
       out_npy = in_npy.sum(axis=axis, keepdims=keepdims)
       data_tvm = tvm.nd.array(in_npy, device=dev)
       out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=out_dtype)
   
       # Save the generated CUDA kernel source for inspection
       with open("lib_sum.cu", "w") as fo:
           fo.write(func.imported_modules[0].get_source())
   
       # Run the kernel once and check the result
       func(data_tvm, out_tvm)
       tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
   
   if __name__ == "__main__":
       test_reduce_sum(in_shape=(1, 32, 32, 1),
                       axis=(0, 1, 2),
                       keepdims=False,
                       type="sum")
   ```
   Also, I can reproduce the same memcheck error by running
   ```bash
   cuda-memcheck python3 test_reduce_sum.py
   ```
   and the CUDA kernel generated by `test_reduce_sum.py` is saved to `lib_sum.cu`:
   ```cpp
   27 extern "C" __global__ void sum_kernel0(int* __restrict__ A, int* __restrict__ A_red) {
   28   int A_red_rf[1];
   29   int red_buf0[1];
   30   A_red_rf[(0)] = 0;
   31   for (int k0_k1_fused_k2_fused_outer = 0; k0_k1_fused_k2_fused_outer < 32; ++k0_k1_fused_k2_fused_outer) {
   32     if (((int)threadIdx.y) < 1) {
   33       A_red_rf[(0)] = (A_red_rf[(0)] + A[((((k0_k1_fused_k2_fused_outer * 32) + ((int)threadIdx.x)) + ((int)threadIdx.y)))]);
   34     }
   35   }
   36   uint mask[1];
   37   int t0[1];
   38   red_buf0[(0)] = A_red_rf[(0)];
   39   mask[(0)] = __activemask();
   40   t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 16, 32);
   41   red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
   42   t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 8, 32);
   43   red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
   44   t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32);
   45   red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
   46   t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32);
   47   red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
   48   t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32);
   49   red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
   50   red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], 0, 32);
   51   if (((int)threadIdx.x) == 0) {
   52     A_red[(((int)threadIdx.y))] = red_buf0[(0)];
   53   }
   54 }
   ```
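Lines 38 to 50 are a warp-shuffle tree reduction. The plain-Python sketch below replays them for one warp, assuming that each row of 32 threads with the same `threadIdx.y` is one fully active warp (as with block(32, 32, 1)) and using the CUDA rule that `__shfl_down_sync` leaves the upper `delta` lanes with their own value. The helper names are hypothetical.

```python
WARP_SIZE = 32


def shfl_down(vals, delta, width=WARP_SIZE):
    """Model __shfl_down_sync: lane i reads lane i + delta; the top `delta` lanes keep their own value."""
    return [vals[i + delta] if i + delta < width else vals[i] for i in range(width)]


def shfl(vals, src_lane):
    """Model __shfl_sync(mask, var, src_lane, 32): every lane reads lane `src_lane`."""
    return [vals[src_lane]] * len(vals)


def warp_reduce_sum(lane_vals):
    """Replay lines 38-50 of sum_kernel0 for one warp (one value of threadIdx.y)."""
    red_buf0 = list(lane_vals)                # line 38: copy the per-lane partial sum
    for delta in (16, 8, 4, 2, 1):            # lines 40-49: halve the stride each step
        t0 = shfl_down(red_buf0, delta)
        red_buf0 = [a + b for a, b in zip(red_buf0, t0)]
    return shfl(red_buf0, 0)                  # line 50: broadcast lane 0's total
```

After the final broadcast every lane holds the row total, which is why the kernel only needs lane `threadIdx.x == 0` to store it; which `threadIdx.y` rows may store is the part left unguarded.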
   From `python/tvm/topi/cuda/reduction.py`, we can also infer that the kernel is launched with grid (1, 1, 1) and block (32, 32, 1).
   Comparing the kernel code with the error report, the load on line 32 is guarded by `threadIdx.y < 1`, but the final store step lacks the same constraint on `threadIdx.y`.
   Whenever the output has fewer than 32 elements, threads with `threadIdx.x == 0` and `threadIdx.y` at or beyond the output size write past the end of `A_red`. Lines 51 to 53 should instead read:
   ```cpp
    51   if ((((int)threadIdx.x) == 0) && (((int)threadIdx.y) < 1)) {
    52     A_red[(((int)threadIdx.y))] = red_buf0[(0)];
    53   }
   ```
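To check the effect of the extra guard on this particular kernel, here is a rough plain-Python model of `sum_kernel0` for the (1, 32, 32, 1) input. It is a sketch under stated assumptions: one block of 32x32 threads, the warp shuffle replaced by an exact per-row sum, and a hypothetical `guard_store` flag that toggles the `threadIdx.y` condition on lines 51 to 53.

```python
BLOCK_X, BLOCK_Y = 32, 32  # block(32, 32, 1), grid(1, 1, 1), as inferred above


def run_sum_kernel0(A, out_size, guard_store):
    """Plain-Python model of sum_kernel0 for one thread block.

    A is the flattened (1, 32, 32, 1) input. Returns (A_red, out_of_bounds_rows).
    The warp shuffle on lines 36-50 is modelled as an exact sum over each row of lanes.
    """
    A_red = [None] * out_size
    out_of_bounds = []
    for ty in range(BLOCK_Y):
        # Lines 28-35: each lane accumulates a strided slice, but the load is
        # guarded by threadIdx.y < 1, so rows with ty >= 1 keep a partial sum of 0.
        partial = [0] * BLOCK_X
        for tx in range(BLOCK_X):
            for outer in range(32):
                if ty < 1:
                    partial[tx] += A[outer * 32 + tx + ty]
        row_total = sum(partial)  # value every lane holds after line 50
        # Lines 51-53: lane 0 of each row stores; guard_store adds threadIdx.y < out_size.
        for tx in range(BLOCK_X):
            if tx == 0 and (not guard_store or ty < out_size):
                if ty < out_size:
                    A_red[ty] = row_total
                else:
                    out_of_bounds.append(ty)  # on the GPU this write lands past A_red
    return A_red, out_of_bounds
```

With `guard_store=False` the model records stores from rows 1 through 31, all past the single-element output; with the guard, `A_red[0]` still equals the full sum and no out-of-bounds store remains.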
   
   ## Fix the reduction schedule
   Based on this analysis, we can fix the schedule of all reduction ops in `python/tvm/topi/cuda/reduction.py`.
   I amended line 89 to add a constraint on `thread_y` to the store predicate:
   ```python
   sch[real_output].set_store_predicate(
       tvm.tir.all(thread_x.equal(0),
                   block_x * num_thread + thread_y < reduce(mul, real_output.shape)))
   ```
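The new predicate also has to hold when the output spans several blocks. The sketch below enumerates every store of one launch in plain Python. It assumes, as the predicate's form suggests, that the thread at (`block_x`, `thread_y`) handles output index `block_x * num_thread + thread_y`, that `num_thread` is 32, and that the grid has ceil(N / 32) blocks for an N-element output; the grid size and the helper names `store_indices` and `check` are assumptions, not taken from `reduction.py`.

```python
from functools import reduce
from operator import mul

NUM_THREAD = 32  # threads per block along x and y, as in block(32, 32, 1)


def store_indices(out_shape, fixed):
    """Enumerate the output indices stored by one launch of the reduction kernel."""
    out_size = reduce(mul, out_shape, 1)
    num_blocks = -(-out_size // NUM_THREAD)  # ceil(out_size / NUM_THREAD)
    stores = []
    for block_x in range(num_blocks):
        for thread_y in range(NUM_THREAD):
            idx = block_x * NUM_THREAD + thread_y
            for thread_x in range(NUM_THREAD):
                pred = thread_x == 0                # old predicate: thread_x.equal(0)
                if fixed:
                    pred = pred and idx < out_size  # new: ... and idx < reduce(mul, shape)
                if pred:
                    stores.append(idx)
    return stores


def check(out_shape, fixed):
    """Return (out-of-bounds indices, whether every valid index is stored exactly once)."""
    out_size = reduce(mul, out_shape, 1)
    stores = store_indices(out_shape, fixed)
    oob = [i for i in stores if i >= out_size]
    exact_once = sorted(i for i in stores if i < out_size) == list(range(out_size))
    return oob, exact_once
```

For a 40-element output, two blocks are launched: the old predicate stores 24 times past the end, while the new one stores each of the 40 elements exactly once. When the output size is a multiple of 32, both predicates behave identically, which may explain why ordinary tests did not fail.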
   
   BTW, since this bug was only caught by the `cuda-memcheck` tool, I think it is worth adding a `cuda-memcheck` run to TVM's GitHub Actions CI to catch bugs like this in the future.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
