Hzfengsy opened a new pull request, #18134:
URL: https://github.com/apache/tvm/pull/18134

   This commit implements `T.thread_return()`, which lets a thread exit a CUDA
   kernel early. It is useful when threads need to return conditionally, for
   example based on their thread indices or on other runtime conditions.
   
   Key changes:
   - Add thread_return builtin in TIR
   - Implement CUDA codegen for thread_return
   - Add Python bindings for T.thread_return()
   - Update TIR IR builder to support thread_return
   - Add tests demonstrating thread_return usage
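To illustrate what the codegen change amounts to, here is a toy emitter, not TVM's code. The node classes `ThreadReturn`, `Store`, and `IfThen` and the function `emit` are hypothetical stand-ins for TIR nodes and the CUDA code generator. It shows only the core idea: a thread-return node lowers to a plain `return;` inside the kernel body, so only the threads that reach it stop executing.

```python
from dataclasses import dataclass
from typing import List, Union


# Hypothetical IR nodes for illustration only; these are NOT tvm.tir classes.
@dataclass
class ThreadReturn:
    pass


@dataclass
class Store:
    target: str
    value: str


@dataclass
class IfThen:
    cond: str
    body: List["Stmt"]


Stmt = Union[ThreadReturn, Store, IfThen]


def emit(stmts: List[Stmt], indent: int = 1) -> List[str]:
    """Emit CUDA-like source lines for a list of toy statements."""
    pad = "  " * indent
    lines: List[str] = []
    for s in stmts:
        if isinstance(s, ThreadReturn):
            # An early thread exit becomes a plain `return;` in the kernel.
            lines.append(pad + "return;")
        elif isinstance(s, Store):
            lines.append(f"{pad}{s.target} = {s.value};")
        elif isinstance(s, IfThen):
            lines.append(f"{pad}if ({s.cond}) {{")
            lines.extend(emit(s.body, indent + 1))
            lines.append(pad + "}")
        else:
            raise TypeError(f"unsupported statement: {s!r}")
    return lines


# Mirrors the shape of the kernel body from the example in this PR.
body = [
    IfThen("16 <= ((int)threadIdx.x)", [ThreadReturn()]),
    Store("B[idx]", "A[idx]"),
]
print("\n".join(emit(body)))
```

Nested statements simply increase the indent, which keeps the emitted `return;` inside its guarding `if`.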
   
   Example usage:
   ```python
   from tvm.script import tir as T

   @T.prim_func
   def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
       for i in T.thread_binding(16, thread="blockIdx.x"):
           for j in T.thread_binding(32, thread="threadIdx.x"):
               if j >= 16:
                   T.thread_return()  # Early exit for threads with j >= 16
               B[i, j] = A[i, j]
   ```
   
   The generated CUDA code is:
   
   ```cuda
   extern "C" __global__ void __launch_bounds__(32) main_kernel(float* __restrict__ A, float* __restrict__ B) {
     if (16 <= ((int)threadIdx.x)) {
       return;
     }
     B[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))] = A[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))];
   }
   ```
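The early exit matters here because each block launches 32 threads while a row of `A` and `B` holds only 16 elements, and the store uses the flat index `blockIdx.x * 16 + threadIdx.x`. The host-side model below is an assumption-laden sketch: it replays only the index arithmetic of the generated kernel (the helper name `stored_indices` is hypothetical) and does not execute anything on a GPU. It shows that with the early return every element of `B` is written exactly once, whereas without it threads 16 to 31 of block `i` would store into row `i + 1`, and block 15 would write past the end of the 256-element buffer.

```python
from typing import List


def stored_indices(early_return: bool) -> List[int]:
    """Flat indices of B stored by each (blockIdx.x, threadIdx.x) pair."""
    indices: List[int] = []
    for block in range(16):       # blockIdx.x extent
        for thread in range(32):  # threadIdx.x extent (launch bound 32)
            if early_return and 16 <= thread:
                continue  # models `return;`: this thread does no further work
            indices.append(block * 16 + thread)  # same index as the kernel
    return indices


guarded = stored_indices(early_return=True)
unguarded = stored_indices(early_return=False)
print(len(guarded), len(set(guarded)), max(guarded))        # 256 256 255
print(len(unguarded), len(set(unguarded)), max(unguarded))  # 512 272 271
```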


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.