https://github.com/abidh updated https://github.com/llvm/llvm-project/pull/164655
>From 56037a64dbd5f73d2c020dd5d58d2c99758b35d0 Mon Sep 17 00:00:00 2001 From: Abid Qadeer <[email protected]> Date: Tue, 21 Oct 2025 20:53:46 +0100 Subject: [PATCH 1/4] Add callback metadata to runtime functions which take callbacks. --- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 25 ++++++++ .../Frontend/OpenMPIRBuilderTest.cpp | 58 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index c164d32f8f98c..312e119c4280d 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -750,6 +750,31 @@ OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) { *MDNode::get(Ctx, {MDB.createCallbackEncoding( 2, {-1, -1}, /* VarArgsArePassed */ true)})); } + + } else if (FnID == OMPRTL___kmpc_distribute_static_loop_4 || + FnID == OMPRTL___kmpc_distribute_static_loop_4u || + FnID == OMPRTL___kmpc_distribute_static_loop_8 || + FnID == OMPRTL___kmpc_distribute_static_loop_8u || + FnID == OMPRTL___kmpc_distribute_for_static_loop_4 || + FnID == OMPRTL___kmpc_distribute_for_static_loop_4u || + FnID == OMPRTL___kmpc_distribute_for_static_loop_8 || + FnID == OMPRTL___kmpc_distribute_for_static_loop_8u || + FnID == OMPRTL___kmpc_for_static_loop_4 || + FnID == OMPRTL___kmpc_for_static_loop_4u || + FnID == OMPRTL___kmpc_for_static_loop_8 || + FnID == OMPRTL___kmpc_for_static_loop_8u) { + if (!Fn->hasMetadata(LLVMContext::MD_callback)) { + LLVMContext &Ctx = Fn->getContext(); + MDBuilder MDB(Ctx); + // Annotate the callback behavior of the runtime function: + // - The callback callee is argument number 1. + // - The first argument of the callback callee is unknown (-1). + // - The second argument of the callback callee is argument number 2 + Fn->addMetadata( + LLVMContext::MD_callback, + *MDNode::get(Ctx, {MDB.createCallbackEncoding( + 1, {-1, 2}, /* VarArgsArePassed */ false)})); + } } LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName() diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index d231a778a8a97..aca2153f85c26 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -7957,4 +7957,62 @@ TEST_F(OpenMPIRBuilderTest, spliceBBWithEmptyBB) { EXPECT_FALSE(Terminator->getDbgRecordRange().empty()); } +TEST_F(OpenMPIRBuilderTest, callBackFunctions) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.Config.IsTargetDevice = true; + OMPBuilder.initialize(); + + // Test multiple runtime functions that should have callback metadata + std::vector<RuntimeFunction> CallbackFunctions = { + OMPRTL___kmpc_distribute_static_loop_4, + OMPRTL___kmpc_distribute_static_loop_4u, + OMPRTL___kmpc_distribute_static_loop_8, + OMPRTL___kmpc_distribute_static_loop_8u, + OMPRTL___kmpc_distribute_for_static_loop_4, + OMPRTL___kmpc_distribute_for_static_loop_4u, + OMPRTL___kmpc_distribute_for_static_loop_8, + OMPRTL___kmpc_distribute_for_static_loop_8u, + OMPRTL___kmpc_for_static_loop_4, + OMPRTL___kmpc_for_static_loop_4u, + OMPRTL___kmpc_for_static_loop_8, + OMPRTL___kmpc_for_static_loop_8u + }; + + for (RuntimeFunction RF : CallbackFunctions) { + Function *Fn = OMPBuilder.getOrCreateRuntimeFunctionPtr(RF); + ASSERT_NE(Fn, nullptr) << "Function should exist for runtime function"; + + MDNode *CallbackMD = Fn->getMetadata(LLVMContext::MD_callback); + EXPECT_NE(CallbackMD, nullptr) << "Function should have callback metadata"; + + if (CallbackMD) { + // Should have at least one callback + EXPECT_GE(CallbackMD->getNumOperands(), 1U); + + // Test first callback entry + MDNode *FirstCallback = cast<MDNode>(CallbackMD->getOperand(0)); + EXPECT_EQ(FirstCallback->getNumOperands(), 4U); + + // Callee index should be valid + auto *CalleeIdxCM = cast<ConstantAsMetadata>(FirstCallback->getOperand(0)); + uint64_t CalleeIdx = cast<ConstantInt>(CalleeIdxCM->getValue())->getZExtValue(); + EXPECT_EQ(CalleeIdx, 1u); + + // Verify payload arguments re (-1, 2) + auto *Arg0CM = cast<ConstantAsMetadata>(FirstCallback->getOperand(1)); + int64_t Arg0 = cast<ConstantInt>(Arg0CM->getValue())->getSExtValue(); + EXPECT_EQ(Arg0, -1); + auto *Arg1CM = cast<ConstantAsMetadata>(FirstCallback->getOperand(2)); + int64_t Arg1 = cast<ConstantInt>(Arg1CM->getValue())->getSExtValue(); + EXPECT_EQ(Arg1, 2); + + // Verify the varArgs is false. + auto *VarArgCM = cast<ConstantAsMetadata>(FirstCallback->getOperand(3)); + uint64_t VarArg = cast<ConstantInt>(VarArgCM->getValue())->getZExtValue(); + EXPECT_EQ(VarArg, 0u); + } + } +} + + } // namespace >From ba23f45c80f4179a20ca6a2e36f7c75870d0bc0b Mon Sep 17 00:00:00 2001 From: Abid Qadeer <[email protected]> Date: Tue, 21 Oct 2025 21:10:52 +0100 Subject: [PATCH 2/4] Allow SPMDisation of function taking callbacks. --- llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 72 ++++++++++++++- .../test/Transforms/OpenMP/callback_guards.ll | 91 +++++++++++++++++++ 2 files changed, 159 insertions(+), 4 deletions(-) create mode 100644 llvm/test/Transforms/OpenMP/callback_guards.ll diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 20fcb7307ff7d..56b49507603eb 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -3659,6 +3659,46 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { static const char ID; }; +static bool isRuntimeFuncionWithCallbacks(RuntimeFunction RF) { + switch (RF) { + case OMPRTL___kmpc_distribute_static_loop_4: + case OMPRTL___kmpc_distribute_static_loop_4u: + case OMPRTL___kmpc_distribute_static_loop_8: + case OMPRTL___kmpc_distribute_static_loop_8u: + case OMPRTL___kmpc_distribute_for_static_loop_4: + case OMPRTL___kmpc_distribute_for_static_loop_4u: + case OMPRTL___kmpc_distribute_for_static_loop_8: + case OMPRTL___kmpc_distribute_for_static_loop_8u: + case OMPRTL___kmpc_for_static_loop_4: + case OMPRTL___kmpc_for_static_loop_4u: + case OMPRTL___kmpc_for_static_loop_8: + case OMPRTL___kmpc_for_static_loop_8u: + return true; + default: + return false; + } +} + +static unsigned runtimeFuncionCallbackArgNo(RuntimeFunction RF) { + switch (RF) { + case OMPRTL___kmpc_distribute_static_loop_4: + case OMPRTL___kmpc_distribute_static_loop_4u: + case OMPRTL___kmpc_distribute_static_loop_8: + case OMPRTL___kmpc_distribute_static_loop_8u: + case OMPRTL___kmpc_distribute_for_static_loop_4: + case OMPRTL___kmpc_distribute_for_static_loop_4u: + case OMPRTL___kmpc_distribute_for_static_loop_8: + case OMPRTL___kmpc_distribute_for_static_loop_8u: + case OMPRTL___kmpc_for_static_loop_4: + case OMPRTL___kmpc_for_static_loop_4u: + case OMPRTL___kmpc_for_static_loop_8: + case OMPRTL___kmpc_for_static_loop_8u: + return 1; + default: + llvm_unreachable("Unexpected runtime function!"); + } +} + /// The function kernel info abstract attribute, basically, what can we say /// about a function with regards to the KernelInfoState. struct AAKernelInfoFunction : AAKernelInfo { @@ -4752,6 +4792,30 @@ struct AAKernelInfoFunction : AAKernelInfo { bool AllSPMDStatesWereFixed = true; auto CheckCallInst = [&](Instruction &I) { auto &CB = cast<CallBase>(I); + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + const auto &It = + OMPInfoCache.RuntimeFunctionIDMap.find(CB.getCalledFunction()); + if (It != OMPInfoCache.RuntimeFunctionIDMap.end()) { + RuntimeFunction RF = It->getSecond(); + if (isRuntimeFuncionWithCallbacks(RF)) { + const unsigned int CallbackArgNo = runtimeFuncionCallbackArgNo(RF); + auto *LoopRegion = dyn_cast<Function>( + CB.getArgOperand(CallbackArgNo)->stripPointerCasts()); + if (LoopRegion) { + auto *FnAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*LoopRegion), DepClassTy::OPTIONAL); + if (FnAA) { + getState() ^= FnAA->getState(); + AllSPMDStatesWereFixed &= + FnAA->SPMDCompatibilityTracker.isAtFixpoint(); + AllParallelRegionStatesWereFixed &= + FnAA->ReachedKnownParallelRegions.isAtFixpoint(); + AllParallelRegionStatesWereFixed &= + FnAA->ReachedUnknownParallelRegions.isAtFixpoint(); + } + } + } + } auto *CBAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); if (!CBAA) @@ -4927,7 +4991,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { // state based on the callee state in updateImpl. return; } - if (NumCallees > 1) { + if (NumCallees > 1 && !isRuntimeFuncionWithCallbacks(It->getSecond())) { indicatePessimisticFixpoint(); return; } @@ -5043,8 +5107,8 @@ struct AAKernelInfoCallSite : AAKernelInfo { // kernel from being SPMD-izable. We mark it as such because we need // further changes in order to also consider the contents of the // callbacks passed to them. - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); + // SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + // SPMDCompatibilityTracker.insert(&CB); break; default: // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, @@ -5097,7 +5161,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { getState() = FnAA->getState(); return ChangeStatus::CHANGED; } - if (NumCallees > 1) + if (NumCallees > 1 && !isRuntimeFuncionWithCallbacks(It->getSecond())) return indicatePessimisticFixpoint(); CallBase &CB = cast<CallBase>(getAssociatedValue()); diff --git a/llvm/test/Transforms/OpenMP/callback_guards.ll b/llvm/test/Transforms/OpenMP/callback_guards.ll new file mode 100644 index 0000000000000..09343fa196d9d --- /dev/null +++ b/llvm/test/Transforms/OpenMP/callback_guards.ll @@ -0,0 +1,91 @@ +; RUN: opt -passes=openmp-opt -S < %s | FileCheck %s + +%struct.ident_t = type { i32, i32, i32, i32, ptr } +%struct.DynamicEnvironmentTy = type { i16 } +%struct.KernelEnvironmentTy = type { %struct.ConfigurationEnvironmentTy, ptr, ptr } +%struct.ConfigurationEnvironmentTy = type { i8, i8, i8, i32, i32, i32, i32, i32, i32 } + +@0 = private unnamed_addr addrspace(1) constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 +@1 = private unnamed_addr addrspace(1) constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr addrspacecast (ptr addrspace(1) @0 to ptr) }, align 8 +@__omp_offloading_10303_1849aab__QQmain_l22_exec_mode = weak protected addrspace(1) constant i8 1 +@__omp_offloading_10303_1849aab__QQmain_l22_dynamic_environment = weak_odr protected addrspace(1) global %struct.DynamicEnvironmentTy zeroinitializer +@__omp_offloading_10303_1849aab__QQmain_l22_kernel_environment = weak_odr protected addrspace(1) constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1, i32 1, i32 256, i32 0, i32 0, i32 4, i32 1024 }, ptr addrspacecast (ptr addrspace(1) @1 to ptr), ptr addrspacecast (ptr addrspace(1) @__omp_offloading_10303_1849aab__QQmain_l22_dynamic_environment to ptr) } + +; Function Attrs: nounwind +define internal void @parallel_func_..omp_par.3(ptr noalias noundef %tid.addr.ascast, ptr noalias noundef %zero.addr.ascast, ptr %0) #1 { +omp.par.entry: + ret void +} + +; Function Attrs: mustprogress +define weak_odr protected amdgpu_kernel void @__omp_offloading_10303_1849aab__QQmain_l22(ptr %0, ptr %1, ptr %2) #4 { +entry: + %7 = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @__omp_offloading_10303_1849aab__QQmain_l22_kernel_environment to ptr), ptr %0) + %exec_user_code = icmp eq i32 %7, -1 + br i1 %exec_user_code, label %user_code.entry, label %worker.exit + +user_code.entry: ; preds = %entry + call void @__kmpc_distribute_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @1 to ptr), ptr @__omp_offloading_10303_1849aab__QQmain_l22..omp_par, ptr %2, i32 100, i32 0, i8 0) + call void @__kmpc_target_deinit() + br label %worker.exit + +worker.exit: ; preds = %entry + ret void +} + + +define internal void @__omp_offloading_10303_1849aab__QQmain_l22..omp_par(i32 %0, ptr %1) { +omp_loop.body: + %gep = getelementptr { ptr, ptr }, ptr %1, i32 0, i32 1 + %p = load ptr, ptr %gep, align 8 + %5 = add i32 %0, 1 + store i32 %5, ptr %p, align 4 + %omp_global_thread_num = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @1 to ptr)) + call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @1 to ptr), i32 %omp_global_thread_num, i32 1, i32 -1, i32 -1, ptr @parallel_func_..omp_par.3, ptr @parallel_func_..omp_par.3.wrapper, ptr %1, i64 1) + %6 = load i32, ptr %p, align 4 + %7 = add i32 %6, 1 + store i32 %7, ptr %p, align 4 + ret void +} + +define internal void @parallel_func_..omp_par.3.wrapper(i16 noundef zeroext %0, i32 noundef %1) { +entry: + %addr = alloca i32, align 4, addrspace(5) + %addr.ascast = addrspacecast ptr addrspace(5) %addr to ptr + %zero = alloca i32, align 4, addrspace(5) + %zero.ascast = addrspacecast ptr addrspace(5) %zero to ptr + %global_args = alloca ptr, align 8, addrspace(5) + %global_args.ascast = addrspacecast ptr addrspace(5) %global_args to ptr + store i32 %1, ptr %addr.ascast, align 4 + store i32 0, ptr %zero.ascast, align 4 + call void @__kmpc_get_shared_variables(ptr %global_args.ascast) + %2 = load ptr, ptr %global_args.ascast, align 8 + %3 = getelementptr inbounds ptr, ptr %2, i64 0 + %structArg = load ptr, ptr %3, align 8 + call void @parallel_func_..omp_par.3(ptr %addr.ascast, ptr %zero.ascast, ptr %structArg) + ret void +} + + +declare void @__kmpc_get_shared_variables(ptr) +declare i32 @__kmpc_target_init(ptr, ptr) +declare noalias ptr @__kmpc_alloc_shared(i64) +declare void @__kmpc_target_deinit() +declare i32 @__kmpc_global_thread_num(ptr) +declare void @__kmpc_parallel_51(ptr, i32, i32, i32, i32, ptr, ptr, ptr, i64) +declare !callback !12 void @__kmpc_distribute_static_loop_4u(ptr, ptr, ptr, i32, i32, i8) + +attributes #1 = { nounwind "frame-pointer"="all" } +attributes #4 = { "kernel" } + +!llvm.module.flags = !{!0, !1} + +!0 = !{i32 7, !"openmp-device", i32 52} +!1 = !{i32 7, !"openmp", i32 52} +!12 = !{!13} +!13 = !{i64 1, i64 -1, i64 -1, i1 false} + +; CHECK: define internal void @__omp_offloading_10303_1849aab__QQmain_l22..omp_par( +; CHECK: region.guarded: +; CHECK: region.guarded{{[0-9]+}}: +; CHECK: ret void >From afb76ff88f4c01539beb0c0ebe0553edf0c09a7b Mon Sep 17 00:00:00 2001 From: Abid Qadeer <[email protected]> Date: Wed, 22 Oct 2025 17:46:11 +0100 Subject: [PATCH 3/4] Fix formatting issues. --- .../Frontend/OpenMPIRBuilderTest.cpp | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index aca2153f85c26..376cee8d708e8 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -7964,38 +7964,39 @@ TEST_F(OpenMPIRBuilderTest, callBackFunctions) { // Test multiple runtime functions that should have callback metadata std::vector<RuntimeFunction> CallbackFunctions = { - OMPRTL___kmpc_distribute_static_loop_4, - OMPRTL___kmpc_distribute_static_loop_4u, - OMPRTL___kmpc_distribute_static_loop_8, - OMPRTL___kmpc_distribute_static_loop_8u, - OMPRTL___kmpc_distribute_for_static_loop_4, - OMPRTL___kmpc_distribute_for_static_loop_4u, - OMPRTL___kmpc_distribute_for_static_loop_8, - OMPRTL___kmpc_distribute_for_static_loop_8u, - OMPRTL___kmpc_for_static_loop_4, - OMPRTL___kmpc_for_static_loop_4u, - OMPRTL___kmpc_for_static_loop_8, - OMPRTL___kmpc_for_static_loop_8u - }; + OMPRTL___kmpc_distribute_static_loop_4, + OMPRTL___kmpc_distribute_static_loop_4u, + OMPRTL___kmpc_distribute_static_loop_8, + OMPRTL___kmpc_distribute_static_loop_8u, + OMPRTL___kmpc_distribute_for_static_loop_4, + OMPRTL___kmpc_distribute_for_static_loop_4u, + OMPRTL___kmpc_distribute_for_static_loop_8, + OMPRTL___kmpc_distribute_for_static_loop_8u, + OMPRTL___kmpc_for_static_loop_4, + OMPRTL___kmpc_for_static_loop_4u, + OMPRTL___kmpc_for_static_loop_8, + OMPRTL___kmpc_for_static_loop_8u}; for (RuntimeFunction RF : CallbackFunctions) { Function *Fn = OMPBuilder.getOrCreateRuntimeFunctionPtr(RF); ASSERT_NE(Fn, nullptr) << "Function should exist for runtime function"; - + MDNode *CallbackMD = Fn->getMetadata(LLVMContext::MD_callback); EXPECT_NE(CallbackMD, nullptr) << "Function should have callback metadata"; - + if (CallbackMD) { // Should have at least one callback EXPECT_GE(CallbackMD->getNumOperands(), 1U); - + // Test first callback entry MDNode *FirstCallback = cast<MDNode>(CallbackMD->getOperand(0)); EXPECT_EQ(FirstCallback->getNumOperands(), 4U); - + // Callee index should be valid - auto *CalleeIdxCM = cast<ConstantAsMetadata>(FirstCallback->getOperand(0)); - uint64_t CalleeIdx = cast<ConstantInt>(CalleeIdxCM->getValue())->getZExtValue(); + auto *CalleeIdxCM = + cast<ConstantAsMetadata>(FirstCallback->getOperand(0)); + uint64_t CalleeIdx = + cast<ConstantInt>(CalleeIdxCM->getValue())->getZExtValue(); EXPECT_EQ(CalleeIdx, 1u); // Verify payload arguments re (-1, 2) @@ -8014,5 +8015,4 @@ TEST_F(OpenMPIRBuilderTest, callBackFunctions) { } } - } // namespace >From 331a07bca2370116ab3e6cc82a57efa5ba8132ea Mon Sep 17 00:00:00 2001 From: Abid Qadeer <[email protected]> Date: Wed, 22 Oct 2025 18:55:14 +0100 Subject: [PATCH 4/4] Handle kmpc_reduction_get_fixed_buffer. --- llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 56b49507603eb..24550dad49d30 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -5012,6 +5012,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { case OMPRTL___kmpc_barrier: case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2: case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2: + case OMPRTL___kmpc_reduction_get_fixed_buffer: case OMPRTL___kmpc_error: case OMPRTL___kmpc_flush: case OMPRTL___kmpc_get_hardware_thread_id_in_block: _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
