MoebiusMeow opened a new issue, #12274:
URL: https://github.com/apache/tvm/issues/12274

   Sub-warp reduction was introduced in PR #10207
   
   It turns out that the generated code directly uses the return value of CUDA 
`__shfl_down_sync` even when the source lane is an inactive thread outside the 
`mask`, which is undefined behavior. (The value was always `0` in my tests, but 
that is not guaranteed.)
   
   In the present unit tests, only `sum` reduction is tested, so the `0` from 
inactive threads does not affect the result (it is the identity element of `sum`).
   However, for other commutative reductions such as `max`, the result is wrong 
whenever the input values can be negative.
   
   Here is the test script, based on the example code in PR #10207:
    
   ```python
   import tvm
   import numpy as np
   from tvm.script import tir as T
   
   @T.prim_func
   def reduce(a: T.handle, b: T.handle) -> None:
       A = T.match_buffer(a, [1024, 11])
       B = T.match_buffer(b, [1024])
   
       for i, j in T.grid(1024, 11):
           with T.block("test_reduce"):
               vi, vj = T.axis.remap("SR", [i, j]) 
               with T.init():
                   B[vi] = T.float32(-3.4028234663852886e38)
               B[vi] = T.max(B[vi], A[vi, vj])
   
   sch = tvm.tir.Schedule(reduce)
   blk = sch.get_block("test_reduce")
   i, j = sch.get_loops(blk)
   sch.bind(i, "blockIdx.x")
   sch.bind(j, "threadIdx.x")
   f = tvm.build(sch.mod["main"], target="cuda")
   print(f.imported_modules[0].get_source())
   
   def test_max():
       a_np = -np.random.rand(1024, 11).astype("float32")
       b_np = a_np.max(axis=-1).astype("float32")
       a = tvm.nd.array(a_np, tvm.cuda(0))
       b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
   
       f(a, b)
       print(b)
       print("=" * 30) 
       print(b_np)
       assert np.allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
   
   test_max() 
   ```
   
   CUDA code:
   ```cuda
   extern "C" __global__ void __launch_bounds__(11) 
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
     float red_buf0[1];
     uint mask[1];
     float t0[1];
     red_buf0[0] = A[((((int)blockIdx.x) * 11) + ((int)threadIdx.x))];
     mask[0] = __activemask();
     t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 8, 32);
     red_buf0[0] = max(red_buf0[0], t0[0]);
     t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 4, 32);
     red_buf0[0] = max(red_buf0[0], t0[0]);
     t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);
     red_buf0[0] = max(red_buf0[0], t0[0]);
     t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
     red_buf0[0] = max(red_buf0[0], t0[0]);
     red_buf0[0] = __shfl_sync(mask[0], red_buf0[0], 0, 32);
     B[((int)blockIdx.x)] = red_buf0[0];
   }
   
   ```
   
   
   The statement `t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 8, 32);` reads 
`red_buf0[0]` from threads 8 to 18, while threads 11 to 18 are inactive.
   
   The `mask` in `__shfl_down_sync` only controls which threads synchronize; when 
the source lane is inactive, the returned value is undefined, yet the generated 
code still uses it on the next line.
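   The failure can be reproduced outside CUDA with a small pure-Python model of the generated shuffle tree. This is a sketch under the assumption, observed above, that reads from inactive lanes return `0`; `ACTIVE = 11` mirrors the 11-thread launch in the test script:

   ```python
   # Pure-Python model of the generated sub-warp shuffle reduction
   # (not TVM code). Lanes 0..10 are active and hold negative values;
   # reads from inactive lanes are modeled as returning 0.0, matching
   # the behavior observed in this report.
   WARP = 32
   ACTIVE = 11

   def shfl_down(buf, delta):
       # Lane i reads from lane i + delta; an inactive source lane
       # yields an undefined value, modeled here as 0.0.
       return [buf[i + delta] if i + delta < ACTIVE else 0.0
               for i in range(len(buf))]

   def subwarp_max(values):
       buf = list(values) + [0.0] * (WARP - ACTIVE)  # inactive lanes
       for delta in (8, 4, 2, 1):                    # same tree as the CUDA code
           t = shfl_down(buf, delta)
           buf = [max(b, x) for b, x in zip(buf, t)]
       return buf[0]  # __shfl_sync broadcasts lane 0

   vals = [-0.5 - 0.01 * i for i in range(ACTIVE)]   # all negative
   print(subwarp_max(vals))  # prints 0.0, leaked from inactive lanes
   ```

   Already at `delta = 8`, lanes 3..10 read from inactive lanes 11..18 and fold a `0.0` into their accumulator, which then propagates to lane 0.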
   
   ### Test result
   ```
   [0. 0. 0. ... 0. 0. 0.]
   ==============================
   [-0.1180084  -0.02599742 -0.30190793 ... -0.01682338 -0.0704228
    -0.05579288]
   Traceback (most recent call last):
     File "tmp.py", line 39, in <module>
       test_max()
     File "tmp.py", line 37, in test_max
       assert np.allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
   AssertionError
   ```
   The output is an all-`0` tensor rather than the maximum of the negative 
inputs.
   
   ### Environment
   TVM main branch
   CUDA 10.2 (Also in CUDA 11)
   
   ### More
   This bug also causes the tuned `Softmax` ONNX op to return `nan` in my BERT 
model (since sub-warp reduction is usually the fastest candidate).
   