https://github.com/kevinsala created 
https://github.com/llvm/llvm-project/pull/199483

Until now, the number of threads and blocks has been treated as strict only
when the kernel is in bare mode. In that mode, the values passed in
UserNumBlocks and UserThreadLimit are not adjusted and are the definitive
values used to launch the kernel. This commit decouples strictness from the
kernel mode by adding a StrictBlocksAndThreads flag to the kernel arguments.

The kernel replay tool will use this flag. This change also paves the way for
the upcoming OpenMP dims modifier, which configures multidimensional teams and
leagues and will include strictness options for teams and threads.

All bare kernels must request strict behavior. Assertions are added to check
this condition.

>From 56b8d64b07d5710019223928518b432739162ad2 Mon Sep 17 00:00:00 2001
From: Kevin Sala <[email protected]>
Date: Tue, 12 May 2026 21:09:21 -0700
Subject: [PATCH] [offload][OpenMP] Add strict flags for blocks and threads in
 kernel arguments

Until now, the number of threads and blocks has been treated as strict
only when the kernel is in bare mode. In that mode, the values passed in
UserNumBlocks and UserThreadLimit are not adjusted and are used as-is to
launch the kernel. This commit decouples strictness from the kernel mode
by adding a StrictBlocksAndThreads flag to the kernel arguments.

The kernel replay tool will use this flag. This change also paves the
way for the upcoming OpenMP dims modifier, which configures
multidimensional teams and leagues and will include strictness options
for teams and threads.

Bare kernels must request strict behavior. Assertions are added to
check this condition.
---
 clang/lib/CodeGen/CGOpenMPRuntime.cpp         |  3 +-
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  8 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  7 ++-
 offload/include/Shared/APITypes.h             |  6 ++-
 offload/liboffload/src/OffloadImpl.cpp        |  1 +
 offload/libomptarget/KernelLanguage/API.cpp   |  1 +
 offload/libomptarget/omptarget.cpp            |  1 +
 .../common/include/PluginInterface.h          |  5 +-
 .../common/src/PluginInterface.cpp            | 49 ++++++++++---------
 9 files changed, 50 insertions(+), 31 deletions(-)

diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 60bfe3e9d43f7..2fc6128b78476 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -10984,7 +10984,8 @@ static void emitTargetCallKernelLaunch(
 
     llvm::OpenMPIRBuilder::TargetKernelArgs Args(
         NumTargetItems, RTArgs, NumIterations, NumTeams, NumThreads,
-        DynCGroupMem, HasNoWait, DynCGroupMemFallback);
+        DynCGroupMem, HasNoWait, /*StrictBlocksAndThreads=*/IsBare,
+        DynCGroupMemFallback);
 
     llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
         cantFail(OMPRuntime->getOMPBuilder().emitKernelLaunch(
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 2b790458f3c32..961b9958319a4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2771,6 +2771,9 @@ class OpenMPIRBuilder {
     Value *DynCGroupMem = nullptr;
     /// True if the kernel has 'no wait' clause.
     bool HasNoWait = false;
+    /// True if the kernel strictly requires the number of blocks and threads
+    /// above to run.
+    bool StrictBlocksAndThreads = false;
     /// The fallback mechanism for the shared memory.
     omp::OMPDynGroupprivateFallbackType DynCGroupMemFallback =
         omp::OMPDynGroupprivateFallbackType::Abort;
@@ -2780,12 +2783,13 @@ class OpenMPIRBuilder {
     TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
                      Value *NumIterations, ArrayRef<Value *> NumTeams,
                      ArrayRef<Value *> NumThreads, Value *DynCGroupMem,
-                     bool HasNoWait,
+                     bool HasNoWait, bool StrictBlocksAndThreads,
                      omp::OMPDynGroupprivateFallbackType DynCGroupMemFallback)
         : NumTargetItems(NumTargetItems), RTArgs(RTArgs),
           NumIterations(NumIterations), NumTeams(NumTeams),
           NumThreads(NumThreads), DynCGroupMem(DynCGroupMem),
-          HasNoWait(HasNoWait), DynCGroupMemFallback(DynCGroupMemFallback) {}
+          HasNoWait(HasNoWait), StrictBlocksAndThreads(StrictBlocksAndThreads),
+          DynCGroupMemFallback(DynCGroupMemFallback) {}
   };
 
   /// Create the kernel args vector used by emitTargetKernel. This function
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ecff0c9b0aac4..1f2f3546f596e 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -647,7 +647,12 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
   Value *DynCGroupMemFallbackFlag =
       Builder.getInt64(static_cast<uint64_t>(KernelArgs.DynCGroupMemFallback));
   DynCGroupMemFallbackFlag = Builder.CreateShl(DynCGroupMemFallbackFlag, 2);
+
+  Value *StrictFlag = Builder.getInt64(KernelArgs.StrictBlocksAndThreads);
+  StrictFlag = Builder.CreateShl(StrictFlag, 5);
+
   Value *Flags = Builder.CreateOr(HasNoWaitFlag, DynCGroupMemFallbackFlag);
+  Flags = Builder.CreateOr(Flags, StrictFlag);
 
   assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());
 
@@ -9783,7 +9788,7 @@ static void emitTargetCall(
 
     KArgs = OpenMPIRBuilder::TargetKernelArgs(
         NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, 
DynCGroupMem,
-        HasNoWait, DynCGroupMemFallback);
+        HasNoWait, /*StrictBlocksAndThreads=*/false, DynCGroupMemFallback);
 
     // Assume no error was returned because TaskBodyCB and
     // EmitTargetCallFallbackCB don't produce any.
diff --git a/offload/include/Shared/APITypes.h b/offload/include/Shared/APITypes.h
index 212fb285030fb..99c5dcd3b5154 100644
--- a/offload/include/Shared/APITypes.h
+++ b/offload/include/Shared/APITypes.h
@@ -105,8 +105,10 @@ struct KernelArgsTy {
     uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
     uint64_t DynCGroupMemFallback : 2; // The fallback for dynamic cgroup mem.
     uint64_t Cooperative : 1; // Was this kernel spawned as cooperative.
-    uint64_t Unused : 59;
-  } Flags = {0, 0, 0, 0, 0};
+    uint64_t StrictBlocksAndThreads
+        : 1; // The user-requested number of blocks and threads are strict.
+    uint64_t Unused : 58;
+  } Flags = {0, 0, 0, 0, 0, 0};
   // User-requested number of blocks (for x,y,z dimension).
   uint32_t UserNumBlocks[3] = {0, 0, 0};
   // User-requested number of threads (for x,y,z dimension).
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 66fcbbc264ab4..de13fd7c67ee2 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -1122,6 +1122,7 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
   LaunchArgs.UserThreadLimit[1] = LaunchSizeArgs->GroupSize.y;
   LaunchArgs.UserThreadLimit[2] = LaunchSizeArgs->GroupSize.z;
   LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;
+  LaunchArgs.Flags.StrictBlocksAndThreads = true;
 
   while (Properties && Properties->type != OL_KERNEL_LAUNCH_PROP_TYPE_NONE) {
     switch (Properties->type) {
diff --git a/offload/libomptarget/KernelLanguage/API.cpp b/offload/libomptarget/KernelLanguage/API.cpp
index 112b27b707e5a..50f9b695bed6a 100644
--- a/offload/libomptarget/KernelLanguage/API.cpp
+++ b/offload/libomptarget/KernelLanguage/API.cpp
@@ -68,6 +68,7 @@ unsigned llvmLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
   Args.UserThreadLimit[2] = blockDim.z;
   Args.ArgPtrs = reinterpret_cast<void **>(args);
   Args.Flags.IsCUDA = true;
+  Args.Flags.StrictBlocksAndThreads = true;
   return __tgt_target_kernel(nullptr, 0, gridDim.x, blockDim.x, func, &Args);
 }
 }
diff --git a/offload/libomptarget/omptarget.cpp b/offload/libomptarget/omptarget.cpp
index c2456920ebc1b..17b215732d51b 100644
--- a/offload/libomptarget/omptarget.cpp
+++ b/offload/libomptarget/omptarget.cpp
@@ -2481,6 +2481,7 @@ int target_replay(ident_t *Loc, DeviceTy &Device, void *HostPtr,
   KernelArgs.UserThreadLimit[1] = 1;
   KernelArgs.UserThreadLimit[2] = 1;
   KernelArgs.DynCGroupMem = SharedMemorySize;
+  KernelArgs.Flags.StrictBlocksAndThreads = true;
 
   KernelExtraArgsTy KernelExtraArgs{};
   KernelExtraArgs.ReplayOutcome = ReplayOutcome;
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 54aac2f34b590..f99a0e817fd58 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -529,7 +529,7 @@ struct GenericKernelTy {
   /// Get the effective number of threads for the kernel based on the
   /// user-defined number of threads.
   uint32_t getEffectiveNumThreads(GenericDeviceTy &GenericDevice,
-                                  uint32_t UserThreadLimit[3]) const;
+                                  uint32_t UserThreadLimit) const;
 
   /// Get the effective number of blocks for the kernel based on the
   /// user-defined number of blocks and the loop trip count.
@@ -537,8 +537,7 @@ struct GenericKernelTy {
   /// \p IsNumThreadsFromUser is true is \p NumThreads is defined by user via
   /// thread_limit clause.
   uint32_t getEffectiveNumBlocks(GenericDeviceTy &GenericDevice,
-                                 uint32_t UserNumBlocks[3],
-                                 uint64_t LoopTripCount,
+                                 uint32_t UserNumBlocks, uint64_t 
LoopTripCount,
                                  uint32_t &EffectiveNumThreads,
                                  bool IsNumThreadsFromUser) const;
 
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 4379ebd250794..d3d80cad0d86a 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -249,15 +249,27 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
   uint32_t EffectiveNumBlocks[3] = {KernelArgs.UserNumBlocks[0],
                                     KernelArgs.UserNumBlocks[1],
                                     KernelArgs.UserNumBlocks[2]};
-  if (!isBareMode()) {
-    assert(
-        EffectiveNumThreads[1] == 1 && EffectiveNumThreads[2] == 1 &&
-        EffectiveNumBlocks[1] == 1 && EffectiveNumBlocks[2] == 1 &&
-        "Non-bare mode should only use the first thread and block dimensions");
+
+  // Multidimensional is only supported with bare mode for now.
+  assert(isBareMode() ||
+         EffectiveNumThreads[1] == 1 && EffectiveNumThreads[2] == 1 &&
+             EffectiveNumBlocks[1] == 1 && EffectiveNumBlocks[2] == 1 &&
+             "Non-bare mode should only use the first thread and block "
+             "dimensions");
+
+  assert(!KernelArgs.Flags.StrictBlocksAndThreads ||
+         EffectiveNumThreads[0] > 0 && EffectiveNumThreads[1] > 0 &&
+             EffectiveNumThreads[2] > 0 && EffectiveNumBlocks[0] > 0 &&
+             EffectiveNumBlocks[1] > 0 && EffectiveNumBlocks[2] > 0 &&
+             "Strict requires number of blocks and threads greater than zero");
+
+  // Calculate or adjust the effective number of threads and blocks if needed.
+  if (!KernelArgs.Flags.StrictBlocksAndThreads) {
     EffectiveNumThreads[0] =
-        getEffectiveNumThreads(GenericDevice, EffectiveNumThreads);
+        getEffectiveNumThreads(GenericDevice, EffectiveNumThreads[0]);
+
     EffectiveNumBlocks[0] = getEffectiveNumBlocks(
-        GenericDevice, EffectiveNumBlocks, KernelArgs.Tripcount,
+        GenericDevice, EffectiveNumBlocks[0], KernelArgs.Tripcount,
         EffectiveNumThreads[0], KernelArgs.UserThreadLimit[0] > 0);
   }
 
@@ -362,34 +374,27 @@ GenericKernelTy::prepareArgs(GenericDeviceTy &GenericDevice, void **ArgPtrs,
 
 uint32_t
 GenericKernelTy::getEffectiveNumThreads(GenericDeviceTy &GenericDevice,
-                                        uint32_t UserThreadLimit[3]) const {
+                                        uint32_t UserThreadLimit) const {
   assert(!isBareMode() && "bare kernel should not call this function");
 
-  assert(UserThreadLimit[1] == 1 && UserThreadLimit[2] == 1 &&
-         "Multi dimensional launch not supported yet.");
+  if (UserThreadLimit > 0 && isGenericMode())
+    UserThreadLimit += GenericDevice.getWarpSize();
 
-  if (UserThreadLimit[0] > 0 && isGenericMode())
-    UserThreadLimit[0] += GenericDevice.getWarpSize();
-
-  return std::min(MaxNumThreads, (UserThreadLimit[0] > 0)
-                                     ? UserThreadLimit[0]
-                                     : PreferredNumThreads);
+  return std::min(MaxNumThreads, (UserThreadLimit > 0) ? UserThreadLimit
+                                                       : PreferredNumThreads);
 }
 
 uint32_t GenericKernelTy::getEffectiveNumBlocks(
-    GenericDeviceTy &GenericDevice, uint32_t UserNumBlocks[3],
+    GenericDeviceTy &GenericDevice, uint32_t UserNumBlocks,
     uint64_t LoopTripCount, uint32_t &EffectiveNumThreads,
     bool IsNumThreadsFromUser) const {
   assert(!isBareMode() && "bare kernel should not call this function");
 
-  assert(UserNumBlocks[1] == 1 && UserNumBlocks[2] == 1 &&
-         "Multi dimensional launch not supported yet.");
-
-  if (UserNumBlocks[0] > 0) {
+  if (UserNumBlocks > 0) {
     // TODO: We need to honor any value and consequently allow more than the
     // block limit. For this we might need to start multiple kernels or let the
     // blocks start again until the requested number has been started.
-    return std::min(UserNumBlocks[0], GenericDevice.getBlockLimit());
+    return std::min(UserNumBlocks, GenericDevice.getBlockLimit());
   }
 
   // Return the number of blocks required to cover the loop iterations.

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to