https://github.com/abidh updated 
https://github.com/llvm/llvm-project/pull/164655

>From 56037a64dbd5f73d2c020dd5d58d2c99758b35d0 Mon Sep 17 00:00:00 2001
From: Abid Qadeer <[email protected]>
Date: Tue, 21 Oct 2025 20:53:46 +0100
Subject: [PATCH 1/5] Add callback metadata to runtime functions which take
 callbacks.

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 25 ++++++++
 .../Frontend/OpenMPIRBuilderTest.cpp          | 58 +++++++++++++++++++
 2 files changed, 83 insertions(+)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp 
b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index c164d32f8f98c..312e119c4280d 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -750,6 +750,31 @@ OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, 
RuntimeFunction FnID) {
             *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                   2, {-1, -1}, /* VarArgsArePassed */ true)}));
       }
+
+    } else if (FnID == OMPRTL___kmpc_distribute_static_loop_4 ||
+               FnID == OMPRTL___kmpc_distribute_static_loop_4u ||
+               FnID == OMPRTL___kmpc_distribute_static_loop_8 ||
+               FnID == OMPRTL___kmpc_distribute_static_loop_8u ||
+               FnID == OMPRTL___kmpc_distribute_for_static_loop_4 ||
+               FnID == OMPRTL___kmpc_distribute_for_static_loop_4u ||
+               FnID == OMPRTL___kmpc_distribute_for_static_loop_8 ||
+               FnID == OMPRTL___kmpc_distribute_for_static_loop_8u ||
+               FnID == OMPRTL___kmpc_for_static_loop_4 ||
+               FnID == OMPRTL___kmpc_for_static_loop_4u ||
+               FnID == OMPRTL___kmpc_for_static_loop_8 ||
+               FnID == OMPRTL___kmpc_for_static_loop_8u) {
+      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
+        LLVMContext &Ctx = Fn->getContext();
+        MDBuilder MDB(Ctx);
+        // Annotate the callback behavior of the runtime function:
+        //  - The callback callee is argument number 1.
+        //  - The first argument of the callback callee is unknown (-1).
+        //  - The second argument of the callback callee is argument number 2
+        Fn->addMetadata(
+            LLVMContext::MD_callback,
+            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
+                                  1, {-1, 2}, /* VarArgsArePassed */ false)}));
+      }
     }
 
     LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp 
b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d231a778a8a97..aca2153f85c26 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -7957,4 +7957,62 @@ TEST_F(OpenMPIRBuilderTest, spliceBBWithEmptyBB) {
   EXPECT_FALSE(Terminator->getDbgRecordRange().empty());
 }
 
+TEST_F(OpenMPIRBuilderTest, callBackFunctions) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.Config.IsTargetDevice = true;
+  OMPBuilder.initialize();
+
+  // Test multiple runtime functions that should have callback metadata
+  std::vector<RuntimeFunction> CallbackFunctions = {
+    OMPRTL___kmpc_distribute_static_loop_4,
+    OMPRTL___kmpc_distribute_static_loop_4u,
+    OMPRTL___kmpc_distribute_static_loop_8,
+    OMPRTL___kmpc_distribute_static_loop_8u,
+    OMPRTL___kmpc_distribute_for_static_loop_4,
+    OMPRTL___kmpc_distribute_for_static_loop_4u,
+    OMPRTL___kmpc_distribute_for_static_loop_8,
+    OMPRTL___kmpc_distribute_for_static_loop_8u,
+    OMPRTL___kmpc_for_static_loop_4,
+    OMPRTL___kmpc_for_static_loop_4u,
+    OMPRTL___kmpc_for_static_loop_8,
+    OMPRTL___kmpc_for_static_loop_8u
+  };
+
+  for (RuntimeFunction RF : CallbackFunctions) {
+    Function *Fn = OMPBuilder.getOrCreateRuntimeFunctionPtr(RF);
+    ASSERT_NE(Fn, nullptr) << "Function should exist for runtime function";
+    
+    MDNode *CallbackMD = Fn->getMetadata(LLVMContext::MD_callback);
+    EXPECT_NE(CallbackMD, nullptr) << "Function should have callback metadata";
+    
+    if (CallbackMD) {
+      // Should have at least one callback
+      EXPECT_GE(CallbackMD->getNumOperands(), 1U);
+      
+      // Test first callback entry
+      MDNode *FirstCallback = cast<MDNode>(CallbackMD->getOperand(0));
+      EXPECT_EQ(FirstCallback->getNumOperands(), 4U);
+      
+      // Callee index should be valid
+      auto *CalleeIdxCM = 
cast<ConstantAsMetadata>(FirstCallback->getOperand(0));
+      uint64_t CalleeIdx = 
cast<ConstantInt>(CalleeIdxCM->getValue())->getZExtValue();
+      EXPECT_EQ(CalleeIdx, 1u);
+
+      // Verify payload arguments re (-1, 2)
+      auto *Arg0CM = cast<ConstantAsMetadata>(FirstCallback->getOperand(1));
+      int64_t Arg0 = cast<ConstantInt>(Arg0CM->getValue())->getSExtValue();
+      EXPECT_EQ(Arg0, -1);
+      auto *Arg1CM = cast<ConstantAsMetadata>(FirstCallback->getOperand(2));
+      int64_t Arg1 = cast<ConstantInt>(Arg1CM->getValue())->getSExtValue();
+      EXPECT_EQ(Arg1, 2);
+
+      // Verify the varArgs is false.
+      auto *VarArgCM = cast<ConstantAsMetadata>(FirstCallback->getOperand(3));
+      uint64_t VarArg = 
cast<ConstantInt>(VarArgCM->getValue())->getZExtValue();
+      EXPECT_EQ(VarArg, 0u);
+    }
+  }
+}
+
+
 } // namespace

>From ba23f45c80f4179a20ca6a2e36f7c75870d0bc0b Mon Sep 17 00:00:00 2001
From: Abid Qadeer <[email protected]>
Date: Tue, 21 Oct 2025 21:10:52 +0100
Subject: [PATCH 2/5] Allow SPMDisation of functions taking callbacks.

---
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp         | 72 ++++++++++++++-
 .../test/Transforms/OpenMP/callback_guards.ll | 91 +++++++++++++++++++
 2 files changed, 159 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/Transforms/OpenMP/callback_guards.ll

diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp 
b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 20fcb7307ff7d..56b49507603eb 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3659,6 +3659,46 @@ struct AAKernelInfo : public 
StateWrapper<KernelInfoState, AbstractAttribute> {
   static const char ID;
 };
 
+static bool isRuntimeFuncionWithCallbacks(RuntimeFunction RF) {
+  switch (RF) {
+  case OMPRTL___kmpc_distribute_static_loop_4:
+  case OMPRTL___kmpc_distribute_static_loop_4u:
+  case OMPRTL___kmpc_distribute_static_loop_8:
+  case OMPRTL___kmpc_distribute_static_loop_8u:
+  case OMPRTL___kmpc_distribute_for_static_loop_4:
+  case OMPRTL___kmpc_distribute_for_static_loop_4u:
+  case OMPRTL___kmpc_distribute_for_static_loop_8:
+  case OMPRTL___kmpc_distribute_for_static_loop_8u:
+  case OMPRTL___kmpc_for_static_loop_4:
+  case OMPRTL___kmpc_for_static_loop_4u:
+  case OMPRTL___kmpc_for_static_loop_8:
+  case OMPRTL___kmpc_for_static_loop_8u:
+    return true;
+  default:
+    return false;
+  }
+}
+
+static unsigned runtimeFuncionCallbackArgNo(RuntimeFunction RF) {
+  switch (RF) {
+  case OMPRTL___kmpc_distribute_static_loop_4:
+  case OMPRTL___kmpc_distribute_static_loop_4u:
+  case OMPRTL___kmpc_distribute_static_loop_8:
+  case OMPRTL___kmpc_distribute_static_loop_8u:
+  case OMPRTL___kmpc_distribute_for_static_loop_4:
+  case OMPRTL___kmpc_distribute_for_static_loop_4u:
+  case OMPRTL___kmpc_distribute_for_static_loop_8:
+  case OMPRTL___kmpc_distribute_for_static_loop_8u:
+  case OMPRTL___kmpc_for_static_loop_4:
+  case OMPRTL___kmpc_for_static_loop_4u:
+  case OMPRTL___kmpc_for_static_loop_8:
+  case OMPRTL___kmpc_for_static_loop_8u:
+    return 1;
+  default:
+    llvm_unreachable("Unexpected runtime function!");
+  }
+}
+
 /// The function kernel info abstract attribute, basically, what can we say
 /// about a function with regards to the KernelInfoState.
 struct AAKernelInfoFunction : AAKernelInfo {
@@ -4752,6 +4792,30 @@ struct AAKernelInfoFunction : AAKernelInfo {
     bool AllSPMDStatesWereFixed = true;
     auto CheckCallInst = [&](Instruction &I) {
       auto &CB = cast<CallBase>(I);
+      auto &OMPInfoCache = static_cast<OMPInformationCache 
&>(A.getInfoCache());
+      const auto &It =
+          OMPInfoCache.RuntimeFunctionIDMap.find(CB.getCalledFunction());
+      if (It != OMPInfoCache.RuntimeFunctionIDMap.end()) {
+        RuntimeFunction RF = It->getSecond();
+        if (isRuntimeFuncionWithCallbacks(RF)) {
+          const unsigned int CallbackArgNo = runtimeFuncionCallbackArgNo(RF);
+          auto *LoopRegion = dyn_cast<Function>(
+              CB.getArgOperand(CallbackArgNo)->stripPointerCasts());
+          if (LoopRegion) {
+            auto *FnAA = A.getAAFor<AAKernelInfo>(
+                *this, IRPosition::function(*LoopRegion), 
DepClassTy::OPTIONAL);
+            if (FnAA) {
+              getState() ^= FnAA->getState();
+              AllSPMDStatesWereFixed &=
+                  FnAA->SPMDCompatibilityTracker.isAtFixpoint();
+              AllParallelRegionStatesWereFixed &=
+                  FnAA->ReachedKnownParallelRegions.isAtFixpoint();
+              AllParallelRegionStatesWereFixed &=
+                  FnAA->ReachedUnknownParallelRegions.isAtFixpoint();
+            }
+          }
+        }
+      }
       auto *CBAA = A.getAAFor<AAKernelInfo>(
           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
       if (!CBAA)
@@ -4927,7 +4991,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
         // state based on the callee state in updateImpl.
         return;
       }
-      if (NumCallees > 1) {
+      if (NumCallees > 1 && !isRuntimeFuncionWithCallbacks(It->getSecond())) {
         indicatePessimisticFixpoint();
         return;
       }
@@ -5043,8 +5107,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
         // kernel from being SPMD-izable. We mark it as such because we need
         // further changes in order to also consider the contents of the
         // callbacks passed to them.
-        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-        SPMDCompatibilityTracker.insert(&CB);
+        // The calls are no longer marked here: AAKernelInfoFunction merges
+        // the AAKernelInfo state of the callback function into the caller.
         break;
       default:
         // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
@@ -5097,7 +5161,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
         getState() = FnAA->getState();
         return ChangeStatus::CHANGED;
       }
-      if (NumCallees > 1)
+      if (NumCallees > 1 && !isRuntimeFuncionWithCallbacks(It->getSecond()))
         return indicatePessimisticFixpoint();
 
       CallBase &CB = cast<CallBase>(getAssociatedValue());
diff --git a/llvm/test/Transforms/OpenMP/callback_guards.ll 
b/llvm/test/Transforms/OpenMP/callback_guards.ll
new file mode 100644
index 0000000000000..09343fa196d9d
--- /dev/null
+++ b/llvm/test/Transforms/OpenMP/callback_guards.ll
@@ -0,0 +1,91 @@
+; RUN: opt -passes=openmp-opt -S < %s | FileCheck %s
+
+%struct.ident_t = type { i32, i32, i32, i32, ptr }
+%struct.DynamicEnvironmentTy = type { i16 }
+%struct.KernelEnvironmentTy = type { %struct.ConfigurationEnvironmentTy, ptr, 
ptr }
+%struct.ConfigurationEnvironmentTy = type { i8, i8, i8, i32, i32, i32, i32, 
i32, i32 }
+
+@0 = private unnamed_addr addrspace(1) constant [23 x i8] 
c";unknown;unknown;0;0;;\00", align 1
+@1 = private unnamed_addr addrspace(1) constant %struct.ident_t { i32 0, i32 
2, i32 0, i32 22, ptr addrspacecast (ptr addrspace(1) @0 to ptr) }, align 8
+@__omp_offloading_10303_1849aab__QQmain_l22_exec_mode = weak protected 
addrspace(1) constant i8 1
+@__omp_offloading_10303_1849aab__QQmain_l22_dynamic_environment = weak_odr 
protected addrspace(1) global %struct.DynamicEnvironmentTy zeroinitializer
+@__omp_offloading_10303_1849aab__QQmain_l22_kernel_environment = weak_odr 
protected addrspace(1) constant %struct.KernelEnvironmentTy { 
%struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1, i32 1, i32 256, i32 0, 
i32 0, i32 4, i32 1024 }, ptr addrspacecast (ptr addrspace(1) @1 to ptr), ptr 
addrspacecast (ptr addrspace(1) 
@__omp_offloading_10303_1849aab__QQmain_l22_dynamic_environment to ptr) }
+
+; Function Attrs: nounwind
+define internal void @parallel_func_..omp_par.3(ptr noalias noundef 
%tid.addr.ascast, ptr noalias noundef %zero.addr.ascast, ptr %0) #1 {
+omp.par.entry:
+  ret void
+}
+
+; Function Attrs: mustprogress
+define weak_odr protected amdgpu_kernel void 
@__omp_offloading_10303_1849aab__QQmain_l22(ptr %0, ptr %1, ptr %2) #4 {
+entry:
+  %7 = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) 
@__omp_offloading_10303_1849aab__QQmain_l22_kernel_environment to ptr), ptr %0)
+  %exec_user_code = icmp eq i32 %7, -1
+  br i1 %exec_user_code, label %user_code.entry, label %worker.exit
+
+user_code.entry:                                  ; preds = %entry
+  call void @__kmpc_distribute_static_loop_4u(ptr addrspacecast (ptr 
addrspace(1) @1 to ptr), ptr 
@__omp_offloading_10303_1849aab__QQmain_l22..omp_par, ptr %2, i32 100, i32 0, 
i8 0)
+  call void @__kmpc_target_deinit()
+  br label %worker.exit
+
+worker.exit:                                      ; preds = %entry
+  ret void
+}
+
+
+define internal void @__omp_offloading_10303_1849aab__QQmain_l22..omp_par(i32 
%0, ptr %1) {
+omp_loop.body:
+  %gep = getelementptr { ptr, ptr }, ptr %1, i32 0, i32 1
+  %p = load ptr, ptr %gep, align 8
+  %5 = add i32 %0, 1
+  store i32 %5, ptr %p, align 4
+  %omp_global_thread_num = call i32 @__kmpc_global_thread_num(ptr 
addrspacecast (ptr addrspace(1) @1 to ptr))
+  call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @1 to 
ptr), i32 %omp_global_thread_num, i32 1, i32 -1, i32 -1, ptr 
@parallel_func_..omp_par.3, ptr @parallel_func_..omp_par.3.wrapper, ptr %1, i64 
1)
+  %6 = load i32, ptr %p, align 4
+  %7 = add i32 %6, 1
+  store i32 %7, ptr %p, align 4
+  ret void
+}
+
+define internal void @parallel_func_..omp_par.3.wrapper(i16 noundef zeroext 
%0, i32 noundef %1) {
+entry:
+  %addr = alloca i32, align 4, addrspace(5)
+  %addr.ascast = addrspacecast ptr addrspace(5) %addr to ptr
+  %zero = alloca i32, align 4, addrspace(5)
+  %zero.ascast = addrspacecast ptr addrspace(5) %zero to ptr
+  %global_args = alloca ptr, align 8, addrspace(5)
+  %global_args.ascast = addrspacecast ptr addrspace(5) %global_args to ptr
+  store i32 %1, ptr %addr.ascast, align 4
+  store i32 0, ptr %zero.ascast, align 4
+  call void @__kmpc_get_shared_variables(ptr %global_args.ascast)
+  %2 = load ptr, ptr %global_args.ascast, align 8
+  %3 = getelementptr inbounds ptr, ptr %2, i64 0
+  %structArg = load ptr, ptr %3, align 8
+  call void @parallel_func_..omp_par.3(ptr %addr.ascast, ptr %zero.ascast, ptr 
%structArg)
+  ret void
+}
+
+
+declare void @__kmpc_get_shared_variables(ptr)
+declare i32 @__kmpc_target_init(ptr, ptr)
+declare noalias ptr @__kmpc_alloc_shared(i64)
+declare void @__kmpc_target_deinit()
+declare i32 @__kmpc_global_thread_num(ptr)
+declare void @__kmpc_parallel_51(ptr, i32, i32, i32, i32, ptr, ptr, ptr, i64)
+declare !callback !12 void @__kmpc_distribute_static_loop_4u(ptr, ptr, ptr, 
i32, i32, i8)
+
+attributes #1 = { nounwind "frame-pointer"="all" }
+attributes #4 = { "kernel" }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 7, !"openmp-device", i32 52}
+!1 = !{i32 7, !"openmp", i32 52}
+!12 = !{!13}
+!13 = !{i64 1, i64 -1, i64 -1, i1 false}
+
+; CHECK: define internal void 
@__omp_offloading_10303_1849aab__QQmain_l22..omp_par(
+; CHECK: region.guarded:
+; CHECK: region.guarded{{[0-9]+}}:
+; CHECK: ret void

>From afb76ff88f4c01539beb0c0ebe0553edf0c09a7b Mon Sep 17 00:00:00 2001
From: Abid Qadeer <[email protected]>
Date: Wed, 22 Oct 2025 17:46:11 +0100
Subject: [PATCH 3/5] Fix formatting issues.

---
 .../Frontend/OpenMPIRBuilderTest.cpp          | 40 +++++++++----------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp 
b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index aca2153f85c26..376cee8d708e8 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -7964,38 +7964,39 @@ TEST_F(OpenMPIRBuilderTest, callBackFunctions) {
 
   // Test multiple runtime functions that should have callback metadata
   std::vector<RuntimeFunction> CallbackFunctions = {
-    OMPRTL___kmpc_distribute_static_loop_4,
-    OMPRTL___kmpc_distribute_static_loop_4u,
-    OMPRTL___kmpc_distribute_static_loop_8,
-    OMPRTL___kmpc_distribute_static_loop_8u,
-    OMPRTL___kmpc_distribute_for_static_loop_4,
-    OMPRTL___kmpc_distribute_for_static_loop_4u,
-    OMPRTL___kmpc_distribute_for_static_loop_8,
-    OMPRTL___kmpc_distribute_for_static_loop_8u,
-    OMPRTL___kmpc_for_static_loop_4,
-    OMPRTL___kmpc_for_static_loop_4u,
-    OMPRTL___kmpc_for_static_loop_8,
-    OMPRTL___kmpc_for_static_loop_8u
-  };
+      OMPRTL___kmpc_distribute_static_loop_4,
+      OMPRTL___kmpc_distribute_static_loop_4u,
+      OMPRTL___kmpc_distribute_static_loop_8,
+      OMPRTL___kmpc_distribute_static_loop_8u,
+      OMPRTL___kmpc_distribute_for_static_loop_4,
+      OMPRTL___kmpc_distribute_for_static_loop_4u,
+      OMPRTL___kmpc_distribute_for_static_loop_8,
+      OMPRTL___kmpc_distribute_for_static_loop_8u,
+      OMPRTL___kmpc_for_static_loop_4,
+      OMPRTL___kmpc_for_static_loop_4u,
+      OMPRTL___kmpc_for_static_loop_8,
+      OMPRTL___kmpc_for_static_loop_8u};
 
   for (RuntimeFunction RF : CallbackFunctions) {
     Function *Fn = OMPBuilder.getOrCreateRuntimeFunctionPtr(RF);
     ASSERT_NE(Fn, nullptr) << "Function should exist for runtime function";
-    
+
     MDNode *CallbackMD = Fn->getMetadata(LLVMContext::MD_callback);
     EXPECT_NE(CallbackMD, nullptr) << "Function should have callback metadata";
-    
+
     if (CallbackMD) {
       // Should have at least one callback
       EXPECT_GE(CallbackMD->getNumOperands(), 1U);
-      
+
       // Test first callback entry
       MDNode *FirstCallback = cast<MDNode>(CallbackMD->getOperand(0));
       EXPECT_EQ(FirstCallback->getNumOperands(), 4U);
-      
+
       // Callee index should be valid
-      auto *CalleeIdxCM = 
cast<ConstantAsMetadata>(FirstCallback->getOperand(0));
-      uint64_t CalleeIdx = 
cast<ConstantInt>(CalleeIdxCM->getValue())->getZExtValue();
+      auto *CalleeIdxCM =
+          cast<ConstantAsMetadata>(FirstCallback->getOperand(0));
+      uint64_t CalleeIdx =
+          cast<ConstantInt>(CalleeIdxCM->getValue())->getZExtValue();
       EXPECT_EQ(CalleeIdx, 1u);
 
       // Verify payload arguments re (-1, 2)
@@ -8014,5 +8015,4 @@ TEST_F(OpenMPIRBuilderTest, callBackFunctions) {
   }
 }
 
-
 } // namespace

>From 331a07bca2370116ab3e6cc82a57efa5ba8132ea Mon Sep 17 00:00:00 2001
From: Abid Qadeer <[email protected]>
Date: Wed, 22 Oct 2025 18:55:14 +0100
Subject: [PATCH 4/5] Handle __kmpc_reduction_get_fixed_buffer.

---
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp 
b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 56b49507603eb..24550dad49d30 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -5012,6 +5012,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
       case OMPRTL___kmpc_barrier:
       case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
       case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
+      case OMPRTL___kmpc_reduction_get_fixed_buffer:
       case OMPRTL___kmpc_error:
       case OMPRTL___kmpc_flush:
       case OMPRTL___kmpc_get_hardware_thread_id_in_block:

>From 6de5124d17015fc3c192f7d8b4c23b3498ec25d3 Mon Sep 17 00:00:00 2001
From: Abid Qadeer <[email protected]>
Date: Wed, 22 Oct 2025 18:56:21 +0100
Subject: [PATCH 5/5] Use metadata instead of hardcoding function IDs.

Also set metadata in case it was not set (can happen with linked
library functions).
---
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 106 +++++++++++++++-----------
 1 file changed, 61 insertions(+), 45 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp 
b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 24550dad49d30..4f64f3844148e 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -50,6 +50,7 @@
 #include "llvm/IR/IntrinsicsAMDGPU.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -544,6 +545,46 @@ struct OMPInformationCache : public InformationCache {
     collectUses(RFI, /*CollectStats*/ false);
   }
 
+  void initializeMetadata(Function *F, RuntimeFunction RF) {
+    if (!F)
+      return;
+
+    // Loop runtime functions take their loop body as a callback. Mirror the
+    // !callback encoding OpenMPIRBuilder::getOrCreateRuntimeFunction attaches.
+    switch (RF) {
+    case OMPRTL___kmpc_distribute_static_loop_4:
+    case OMPRTL___kmpc_distribute_static_loop_4u:
+    case OMPRTL___kmpc_distribute_static_loop_8:
+    case OMPRTL___kmpc_distribute_static_loop_8u:
+    case OMPRTL___kmpc_distribute_for_static_loop_4:
+    case OMPRTL___kmpc_distribute_for_static_loop_4u:
+    case OMPRTL___kmpc_distribute_for_static_loop_8:
+    case OMPRTL___kmpc_distribute_for_static_loop_8u:
+    case OMPRTL___kmpc_for_static_loop_4:
+    case OMPRTL___kmpc_for_static_loop_4u:
+    case OMPRTL___kmpc_for_static_loop_8:
+    case OMPRTL___kmpc_for_static_loop_8u: {
+      // Keep existing !callback metadata; only add it when it is missing.
+      if (!F->hasMetadata(LLVMContext::MD_callback)) {
+        LLVMContext &Ctx = F->getContext();
+        MDBuilder MDB(Ctx);
+        // Annotate the callback behavior of the runtime function:
+        //  - The callback callee is argument number 1.
+        //  - The first argument of the callback callee is unknown (-1).
+        //  - The second argument of the callback callee is argument number 2.
+        F->addMetadata(
+            LLVMContext::MD_callback,
+            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
+                                  1, {-1, 2}, /* VarArgsArePassed */ false)}));
+      }
+      break;
+    }
+    default:
+      // No metadata needed for other runtime functions.
+      break;
+    }
+  }
+
   // Helper function to recollect uses of all runtime functions.
   void recollectUses() {
     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
@@ -615,6 +656,7 @@ struct OMPInformationCache : public InformationCache {
       RFI.ReturnType = OMPBuilder._ReturnType;                                 
\
       RFI.ArgumentTypes = std::move(ArgsTypes);                                
\
       RFI.Declaration = F;                                                     
\
+      initializeMetadata(F, _Enum);                                            
\
       unsigned NumUses = collectUses(RFI);                                     
\
       (void)NumUses;                                                           
\
       LLVM_DEBUG({                                                             
\
@@ -3659,44 +3701,19 @@ struct AAKernelInfo : public 
StateWrapper<KernelInfoState, AbstractAttribute> {
   static const char ID;
 };
 
-static bool isRuntimeFuncionWithCallbacks(RuntimeFunction RF) {
-  switch (RF) {
-  case OMPRTL___kmpc_distribute_static_loop_4:
-  case OMPRTL___kmpc_distribute_static_loop_4u:
-  case OMPRTL___kmpc_distribute_static_loop_8:
-  case OMPRTL___kmpc_distribute_static_loop_8u:
-  case OMPRTL___kmpc_distribute_for_static_loop_4:
-  case OMPRTL___kmpc_distribute_for_static_loop_4u:
-  case OMPRTL___kmpc_distribute_for_static_loop_8:
-  case OMPRTL___kmpc_distribute_for_static_loop_8u:
-  case OMPRTL___kmpc_for_static_loop_4:
-  case OMPRTL___kmpc_for_static_loop_4u:
-  case OMPRTL___kmpc_for_static_loop_8:
-  case OMPRTL___kmpc_for_static_loop_8u:
-    return true;
-  default:
-    return false;
-  }
+static bool isRuntimeFunctionWithCallbacks(Function *Fn) {
+  return Fn && Fn->hasMetadata(LLVMContext::MD_callback);
 }
 
-static unsigned runtimeFuncionCallbackArgNo(RuntimeFunction RF) {
-  switch (RF) {
-  case OMPRTL___kmpc_distribute_static_loop_4:
-  case OMPRTL___kmpc_distribute_static_loop_4u:
-  case OMPRTL___kmpc_distribute_static_loop_8:
-  case OMPRTL___kmpc_distribute_static_loop_8u:
-  case OMPRTL___kmpc_distribute_for_static_loop_4:
-  case OMPRTL___kmpc_distribute_for_static_loop_4u:
-  case OMPRTL___kmpc_distribute_for_static_loop_8:
-  case OMPRTL___kmpc_distribute_for_static_loop_8u:
-  case OMPRTL___kmpc_for_static_loop_4:
-  case OMPRTL___kmpc_for_static_loop_4u:
-  case OMPRTL___kmpc_for_static_loop_8:
-  case OMPRTL___kmpc_for_static_loop_8u:
-    return 1;
-  default:
-    llvm_unreachable("Unexpected runtime function!");
-  }
+/// Returns the callee argument index from the first !callback encoding of Fn.
+static unsigned runtimeFunctionCallbackArgNo(Function *Fn) {
+  MDNode *CallbackMD = Fn->getMetadata(LLVMContext::MD_callback);
+  assert(CallbackMD && CallbackMD->getNumOperands() > 0 &&
+         "Expected !callback metadata on the runtime function");
+  // Only the first encoding is consulted; operand 0 holds the callee index.
+  auto *Encoding = cast<MDNode>(CallbackMD->getOperand(0));
+  auto *CalleeIdxCM = cast<ConstantAsMetadata>(Encoding->getOperand(0));
+  return cast<ConstantInt>(CalleeIdxCM->getValue())->getZExtValue();
 }
 
 /// The function kernel info abstract attribute, basically, what can we say
@@ -4793,14 +4810,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
     auto CheckCallInst = [&](Instruction &I) {
       auto &CB = cast<CallBase>(I);
       auto &OMPInfoCache = static_cast<OMPInformationCache 
&>(A.getInfoCache());
-      const auto &It =
-          OMPInfoCache.RuntimeFunctionIDMap.find(CB.getCalledFunction());
+      Function *Callee = CB.getCalledFunction();
+      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
       if (It != OMPInfoCache.RuntimeFunctionIDMap.end()) {
-        RuntimeFunction RF = It->getSecond();
-        if (isRuntimeFuncionWithCallbacks(RF)) {
-          const unsigned int CallbackArgNo = runtimeFuncionCallbackArgNo(RF);
-          auto *LoopRegion = dyn_cast<Function>(
-              CB.getArgOperand(CallbackArgNo)->stripPointerCasts());
+        if (isRuntimeFunctionWithCallbacks(Callee)) {
+          const unsigned int ArgNo = runtimeFunctionCallbackArgNo(Callee);
+          auto *LoopRegion =
+              dyn_cast<Function>(CB.getArgOperand(ArgNo)->stripPointerCasts());
           if (LoopRegion) {
             auto *FnAA = A.getAAFor<AAKernelInfo>(
                 *this, IRPosition::function(*LoopRegion), 
DepClassTy::OPTIONAL);
@@ -4991,7 +5007,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
         // state based on the callee state in updateImpl.
         return;
       }
-      if (NumCallees > 1 && !isRuntimeFuncionWithCallbacks(It->getSecond())) {
+      if (NumCallees > 1 && !isRuntimeFunctionWithCallbacks(Callee)) {
         indicatePessimisticFixpoint();
         return;
       }
@@ -5162,7 +5178,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
         getState() = FnAA->getState();
         return ChangeStatus::CHANGED;
       }
-      if (NumCallees > 1 && !isRuntimeFuncionWithCallbacks(It->getSecond()))
+      if (NumCallees > 1 && !isRuntimeFunctionWithCallbacks(F))
         return indicatePessimisticFixpoint();
 
       CallBase &CB = cast<CallBase>(getAssociatedValue());

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to