This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ea4369c221 [TIR] Add `T.thread_return()` for early thread exit in CUDA kernels (#18134)
ea4369c221 is described below

commit ea4369c221a26875c63d29aab435fc453d25b953
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Jul 14 21:13:19 2025 +0800

    [TIR] Add `T.thread_return()` for early thread exit in CUDA kernels (#18134)
    
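    This commit adds `T.thread_return()`, which lets a thread exit a CUDA
    kernel early. This is useful when threads need to return conditionally,
    based on their thread indices or other conditions. Because a thread that
    returns early never reaches later code, `T.thread_return()` should not
    precede a block-wide barrier such as `__syncthreads()` unless every thread
    in the block takes the same path.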
    
    Key changes:
    - Add thread_return builtin in TIR
    - Implement CUDA codegen for thread_return
    - Add Python bindings for T.thread_return()
    - Update TIR IR builder to support thread_return
    - Add a CUDA codegen test demonstrating thread_return usage
    
    Example usage:
    ```python
    @T.prim_func
    def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        for i in T.thread_binding(16, thread="blockIdx.x"):
            for j in T.thread_binding(32, thread="threadIdx.x"):
                if j >= 16:
                    T.thread_return()  # Early exit for threads with j >= 16
                B[i, j] = A[i, j]
    ```
    
    and the generated code is:
    
    ```cuda
    extern "C" __global__ void __launch_bounds__(32) main_kernel(float* __restrict__ A, float* __restrict__ B) {
      if (16 <= ((int)threadIdx.x)) {
        return;
      }
      B[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))] = A[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))];
    }
    ```
---
 include/tvm/tir/builtin.h                        |  4 ++++
 include/tvm/tir/op.h                             |  8 ++++++++
 python/tvm/script/ir_builder/tir/ir.py           |  2 ++
 python/tvm/tir/op.py                             | 17 +++++++++++++++++
 src/target/source/codegen_cuda.cc                |  2 ++
 src/tir/op/builtin.cc                            |  4 ++++
 src/tir/op/op.cc                                 |  6 ++++++
 tests/python/codegen/test_target_codegen_cuda.py | 17 +++++++++++++++++
 8 files changed, 60 insertions(+)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index c057422a02..6b31324fa5 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -45,6 +45,10 @@ namespace builtin {
  * \brief Return value.
  */
 TVM_DLL const Op& ret();
+/*!
+ * \brief Return from a GPU thread.
+ */
+TVM_DLL const Op& thread_return();
 /*!
  * \brief Reinterpret the value using the target type.
  */
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 99139f83b2..3dda3f7c63 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -91,6 +91,14 @@ TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
  */
 TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
 
+/*!
+ * \brief Return from a thread.
+ *
+ * \param span The location of this operation in the source.
+ * \return The return expression.
+ */
+TVM_DLL PrimExpr thread_return(Span span = Span());
+
 /*!
  * Query the maximum possible value of dtype.
  * \param dtype The data type.
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 5864de2cac..c6549ad104 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1927,6 +1927,7 @@ sinh = _op_wrapper(_tir_op.sinh)
 sqrt = _op_wrapper(_tir_op.sqrt)
 tan = _op_wrapper(_tir_op.tan)
 tanh = _op_wrapper(_tir_op.tanh)
+thread_return = _op_wrapper(_tir_op.thread_return)
 trunc = _op_wrapper(_tir_op.trunc)
 truncdiv = _op_wrapper(_tir_op.truncdiv)
 truncmod = _op_wrapper(_tir_op.truncmod)
@@ -2205,6 +2206,7 @@ __all__ = float_types + [
     "sqrt",
     "tan",
     "tanh",
+    "thread_return",
     "trunc",
     "truncdiv",
     "truncmod",
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 155c7e10de..54c70ede7a 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1882,6 +1882,23 @@ def ret(val, span=None):
     return _ffi_api.ret(val, span)
 
 
+def thread_return(span=None):
+    """Return from a GPU thread.
+
+    Parameters
+    ----------
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    ret : PrimExpr
+        The return expression
+    """
+
+    return _ffi_api.thread_return(span)
+
+
 def any(*args, span=None):
     """Create a new experssion of the union of all conditions in the arguments
 
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 35e3d3cb8d..951415c3b3 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1334,6 +1334,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
       LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes;
     }
     EndScope(ssa_scope);
+  } else if (op->op.same_as(builtin::thread_return())) {
+    os << "return";
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 70614dfeeb..12c7c8d33c 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -48,6 +48,10 @@ TIR_DEFINE_BUILTIN_FUNC(ret)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
     .set_num_inputs(1);
 
+TIR_DEFINE_BUILTIN_FUNC(thread_return)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
+    .set_num_inputs(0);
+
 TIR_DEFINE_BUILTIN_FUNC(likely)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 9a520073a6..55b89618a8 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -251,6 +251,12 @@ TVM_FFI_STATIC_INIT_BLOCK({
   refl::GlobalDef().def("tir.ret", ret);
 });
 
+PrimExpr thread_return(Span span) {
+  return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span);
+}
+
+TVM_FFI_REGISTER_GLOBAL("tir.thread_return").set_body_typed(thread_return);
+
 // maximum and min limits
 PrimExpr max_value(const DataType& dtype, Span span) {
   using namespace tir;
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 2d00618eb0..28dfb6b9d4 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -839,5 +839,22 @@ def test_device_host_call_same_func():
     tvm.testing.assert_allclose(c_tvm.numpy(), a_np + b_np)
 
 
+@tvm.testing.requires_cuda
+def test_thread_return():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
+            for bx in T.thread_binding(32, "blockIdx.x"):
+                for tx in T.thread_binding(32, "threadIdx.x"):
+                    if bx >= 16 or tx >= 16:
+                        T.thread_return()
+                    B[bx, tx] = A[bx, tx]
+
+    lib = tvm.compile(Module, target="cuda")
+    cuda_code = lib.mod.imported_modules[0].get_source()
+    assert "return;" in cuda_code
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to