https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/116051
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to influence target device code generation. From cc5c5cc8b1c8b718ae3d0aece3784416460114bc Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Fri, 8 Nov 2024 17:24:47 +0000 Subject: [PATCH] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to influence target device code generation. --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 26 +- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 137 +++++++-- .../Frontend/OpenMPIRBuilderTest.cpp | 281 +++++++++++++++++- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 +- 4 files changed, 420 insertions(+), 34 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index da450ef5adbc14..a85f41e586c514 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2237,6 +2237,26 @@ class OpenMPIRBuilder { int32_t MinThreads = 1; }; + /// Container to pass LLVM IR runtime values or constants related to the + /// number of teams and threads with which the kernel must be launched, as + /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These + /// must be defined in the host prior to the call to the kernel launch OpenMP + /// RTL function. 
+ struct TargetKernelRuntimeAttrs { + SmallVector<Value *, 3> MaxTeams = {nullptr}; + Value *MinTeams = nullptr; + SmallVector<Value *, 3> TargetThreadLimit = {nullptr}; + SmallVector<Value *, 3> TeamsThreadLimit = {nullptr}; + + /// 'parallel' construct 'num_threads' clause value, if present and it is a + /// target SPMD kernel. + Value *MaxThreads = nullptr; + + /// Total number of iterations of the target SPMD kernel or null if it is a + /// generic kernel. + Value *LoopTripCount = nullptr; + }; + /// Data structure that contains the needed information to construct the /// kernel args vector. struct TargetKernelArgs { @@ -2905,11 +2925,14 @@ class OpenMPIRBuilder { /// /// \param Loc where the target data construct was encountered. /// \param IsOffloadEntry whether it is an offload entry. + /// \param IsSPMD whether it is a target SPMD kernel. /// \param CodeGenIP The insertion point where the call to the outlined /// function should be emitted. /// \param EntryInfo The entry information about the function. /// \param DefaultAttrs Structure containing the default numbers of threads /// and teams to launch the kernel with. + /// \param RuntimeAttrs Structure containing the runtime numbers of threads + /// and teams to launch the kernel with. /// \param Inputs The input values to the region that will be passed. /// as arguments to the outlined function. /// \param BodyGenCB Callback that will generate the region code. @@ -2919,11 +2942,12 @@ class OpenMPIRBuilder { // dependency information as passed in the depend clause // \param HasNowait Whether the target construct has a `nowait` clause or not. 
InsertPointOrErrorTy createTarget( - const LocationDescription &Loc, bool IsOffloadEntry, + const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD, OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 302d363965c940..f847f60386df85 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() { return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit); } +static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List, + Module &M) { + if (List.empty()) + return; + + Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0); + + // Convert List to what ConstantArray needs. 
+ SmallVector<Constant *, 8> UsedArray; + UsedArray.reserve(List.size()); + for (auto Item : List) + UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast( + cast<Constant>(&*Item), PtrTy)); + + ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size()); + auto *GV = + new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage, + llvm::ConstantArray::get(ArrTy, UsedArray), Name); + + GV->setSection("llvm.metadata"); +} + +static void +emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + StringRef FunctionName, OMPTgtExecModeFlags Mode, + std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) { + auto *Int8Ty = Type::getInt8Ty(Builder.getContext()); + auto *GVMode = new llvm::GlobalVariable( + OMPBuilder.M, Int8Ty, /*isConstant=*/true, + llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode), + Twine(FunctionName, "_exec_mode")); + GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility); + LLVMCompilerUsed.emplace_back(GVMode); +} + static Expected<Function *> createOutlinedFunction( - OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, StringRef FuncName, SmallVectorImpl<Value *> &Inputs, OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc, @@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction( auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M); + // Forward target-cpu and target-features function attributes from the + // original function to the new outlined function. 
+ Function *ParentFn = Builder.GetInsertBlock()->getParent(); + + auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu"); + if (TargetCpuAttr.isStringAttribute()) + Func->addFnAttr(TargetCpuAttr); + + auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features"); + if (TargetFeaturesAttr.isStringAttribute()) + Func->addFnAttr(TargetFeaturesAttr); + + if (OMPBuilder.Config.isTargetDevice()) { + std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed; + emitExecutionMode(OMPBuilder, Builder, FuncName, + IsSPMD ? OMP_TGT_EXEC_MODE_SPMD + : OMP_TGT_EXEC_MODE_GENERIC, + LLVMCompilerUsed); + emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M); + } + // Save insert point. IRBuilder<>::InsertPointGuard IPG(Builder); // If there's a DISubprogram associated with current function, then @@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction( // Insert target init call in the device compilation pass. if (OMPBuilder.Config.isTargetDevice()) Builder.restoreIP( - OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs)); + OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs)); BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock(); @@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder, static Error emitTargetOutlinedFunction( OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry, - TargetRegionEntryInfo &EntryInfo, + bool IsSPMD, TargetRegionEntryInfo &EntryInfo, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, Function *&OutlinedFn, Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs, @@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction( OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction = [&](StringRef EntryFnName) { - return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs, + return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs, EntryFnName, Inputs, CBFunc, ArgAccessorFuncCB); }; @@ -7304,6 
+7360,7 @@ static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, + const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, Function *OutlinedFn, Constant *OutlinedFnID, SmallVectorImpl<Value *> &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, @@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, /*ForEndCall=*/false); SmallVector<Value *, 3> NumTeamsC; + for (auto [DefaultVal, RuntimeVal] : + zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams)) + NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal)); + + // Calculate number of threads: 0 if no clauses specified, otherwise it is the + // minimum between optional THREAD_LIMIT and NUM_THREADS clauses. + auto InitMaxThreadsClause = [&Builder](Value *Clause) { + if (Clause) + Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), + /*isSigned=*/false); + return Clause; + }; + auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { + if (Clause) + Result = Result + ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), + Result, Clause) + : Clause; + }; + + // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so + // the NUM_THREADS clause is overridden by THREAD_LIMIT. SmallVector<Value *, 3> NumThreadsC; - for (auto V : DefaultAttrs.MaxTeams) - NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); - for (auto V : DefaultAttrs.MaxThreads) - NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1 + ? 
InitMaxThreadsClause(RuntimeAttrs.MaxThreads) + : nullptr; + + for (auto [TeamsVal, TargetVal] : llvm::zip_equal( + RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) { + Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal); + Value *NumThreads = InitMaxThreadsClause(TargetVal); + + CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); + CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); + + NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0)); + } unsigned NumTargetItems = Info.NumberOfPtrs; // TODO: Use correct device ID @@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, llvm::omp::IdentFlag(0), 0); - // TODO: Use correct NumIterations - Value *NumIterations = Builder.getInt64(0); + + Value *TripCount = RuntimeAttrs.LoopTripCount + ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount, + Builder.getInt64Ty(), + /*isSigned=*/false) + : Builder.getInt64(0); + // TODO: Use correct DynCGGroupMem Value *DynCGGroupMem = Builder.getInt32(0); - KArgs = OpenMPIRBuilder::TargetKernelArgs( - NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC, - DynCGGroupMem, HasNoWait); + KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount, + NumTeamsC, NumThreadsC, + DynCGGroupMem, HasNoWait); // The presence of certain clauses on the target directive require the // explicit generation of the target task. 
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, } OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( - const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, - InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, + const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD, + InsertPointTy AllocaIP, InsertPointTy CodeGenIP, + TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, SmallVector<DependData> Dependencies, bool HasNowait) { + assert((!RuntimeAttrs.LoopTripCount || IsSPMD) && + "trip count not expected if IsSPMD=false"); if (!updateToLocation(Loc)) return InsertPointTy(); @@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // the target region itself is generated using the callbacks CBFunc // and ArgAccessorFuncCB if (Error Err = emitTargetOutlinedFunction( - *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn, - OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB)) + *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs, + OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB)) return Err; // If we are not on the target device, then we need to generate code // to make a remote call (offload) to the previously outlined function // that represents the target region. Do that now. 
if (!Config.isTargetDevice()) - emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn, - OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait); + emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, + OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies, + HasNowait); return Builder.saveIP(); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index b0688d6215e42d..63be7e775b83c9 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6122,8 +6122,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false); OMPBuilder.setConfig(Config); F->setName("func"); + F->addFnAttr("target-cpu", "x86-64"); + F->addFnAttr("target-features", "+mmx,+sse"); IRBuilder<> Builder(BB); - auto Int32Ty = Builder.getInt32Ty(); + auto *Int32Ty = Builder.getInt32Ty(); AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr"); AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr"); @@ -6183,11 +6185,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { - /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; - OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = - OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; + RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20); + RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30); + RuntimeAttrs.MaxThreads = Builder.getInt32(40); + 
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(), + Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); OMPBuilder.finalize(); @@ -6207,6 +6213,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { StringRef FunctionName = KernelLaunchFunc->getName(); EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel")); + // Check num_teams and num_threads in call arguments + EXPECT_TRUE(Call->arg_size() >= 4); + Value *NumTeamsArg = Call->getArgOperand(2); + EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg)); + EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue()); + Value *NumThreadsArg = Call->getArgOperand(3); + EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg)); + EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue()); + + // Check num_teams and num_threads kernel arguments (use number 5 starting + // from the end and counting the call to __tgt_target_kernel as the first use) + Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1); + EXPECT_TRUE(KernelArgs->getNumUses() >= 4); + Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3); + EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr)); + Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser(); + EXPECT_TRUE(isa<StoreInst>(NumTeamsStore)); + Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg)); + auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg); + EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements()); + EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0)); + EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1)); + EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2)); + Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 
2); + EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr)); + Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser(); + EXPECT_TRUE(isa<StoreInst>(NumThreadsStore)); + Value *NumThreadsStoreArg = + cast<StoreInst>(NumThreadsStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg)); + auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg); + EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements()); + EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0)); + EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1)); + EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2)); + // Check the fallback call BasicBlock *FallbackBlock = Branch->getSuccessor(0); Iter = FallbackBlock->rbegin(); @@ -6228,6 +6271,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { StringRef FunctionName2 = OutlinedFunc->getName(); EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading")); + // Check that target-cpu and target-features were propagated to the outlined + // function + EXPECT_EQ(OutlinedFunc->getFnAttribute("target-cpu"), + F->getFnAttribute("target-cpu")); + EXPECT_EQ(OutlinedFunc->getFnAttribute("target-features"), + F->getFnAttribute("target-features")); + EXPECT_FALSE(verifyModule(*M, &errs())); } @@ -6238,6 +6288,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { OMPBuilder.initialize(); F->setName("func"); + F->addFnAttr("target-cpu", "gfx90a"); + F->addFnAttr("target-features", "+gfx9-insts,+wavefrontsize64"); IRBuilder<> Builder(BB); OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); @@ -6297,9 +6349,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, 
EntryInfo, DefaultAttrs, - CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); @@ -6312,6 +6366,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { Function *OutlinedFn = TargetStore->getFunction(); EXPECT_NE(F, OutlinedFn); + // Check that target-cpu and target-features were propagated to the outlined + // function + EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"), + F->getFnAttribute("target-cpu")); + EXPECT_EQ(OutlinedFn->getFnAttribute("target-features"), + F->getFnAttribute("target-features")); + EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage()); // Account for the "implicit" first argument. EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3"); @@ -6378,6 +6439,204 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { auto *ExitBlock = EntryBlockBranch->getSuccessor(1); EXPECT_EQ(ExitBlock->getName(), "worker.exit"); EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI())); + + // Check global exec_mode. 
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used"); + EXPECT_NE(Used, nullptr); + Constant *UsedInit = Used->getInitializer(); + EXPECT_NE(UsedInit, nullptr); + EXPECT_TRUE(isa<ConstantArray>(UsedInit)); + auto *UsedInitData = cast<ConstantArray>(UsedInit); + EXPECT_EQ(1U, UsedInitData->getNumOperands()); + Constant *ExecMode = UsedInitData->getOperand(0); + EXPECT_TRUE(isa<GlobalVariable>(ExecMode)); + Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer(); + EXPECT_NE(ExecModeValue, nullptr); + EXPECT_TRUE(isa<ConstantInt>(ExecModeValue)); + EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC, + cast<ConstantInt>(ExecModeValue)->getZExtValue()); +} + +TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false, + /*OpenMPOffloadMandatory=*/false, + /*HasRequiresReverseOffload=*/false, + /*HasRequiresUnifiedAddress=*/false, + /*HasRequiresUnifiedSharedMemory=*/false, + /*HasRequiresDynamicAllocators=*/false); + OMPBuilder.setConfig(Config); + F->setName("func"); + IRBuilder<> Builder(BB); + + auto BodyGenCB = [&](InsertPointTy, + InsertPointTy CodeGenIP) -> InsertPointTy { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + auto SimpleArgAccessorCB = + [&](llvm::Argument &, llvm::Value *, llvm::Value *&, + llvm::OpenMPIRBuilder::InsertPointTy, + llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + llvm::SmallVector<llvm::Value *> Inputs; + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; + auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy) + -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; }; + + TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); + OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); + OpenMPIRBuilder::TargetKernelDefaultAttrs 
DefaultAttrs = { + /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; + RuntimeAttrs.LoopTripCount = Builder.getInt64(1000); + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(), + Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + assert(AfterIP && "unexpected error"); + Builder.restoreIP(*AfterIP); + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + // Check the kernel launch sequence + auto Iter = F->getEntryBlock().rbegin(); + EXPECT_TRUE(isa<BranchInst>(&*(Iter))); + BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter)); + EXPECT_TRUE(isa<CmpInst>(&*(++Iter))); + EXPECT_TRUE(isa<CallInst>(&*(++Iter))); + CallInst *Call = dyn_cast<CallInst>(&*(Iter)); + + // Check that the kernel launch function is called + Function *KernelLaunchFunc = Call->getCalledFunction(); + EXPECT_NE(KernelLaunchFunc, nullptr); + StringRef FunctionName = KernelLaunchFunc->getName(); + EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel")); + + // Check the trip count kernel argument (use number 5 starting from the end + // and counting the call to __tgt_target_kernel as the first use) + Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1); + EXPECT_TRUE(KernelArgs->getNumUses() >= 6); + Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5); + EXPECT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr)); + Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser(); + EXPECT_TRUE(isa<StoreInst>(TripCountStore)); + Value *TripCountStoreArg = cast<StoreInst>(TripCountStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantInt>(TripCountStoreArg)); + EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue()); + + // Check the fallback call + BasicBlock *FallbackBlock = Branch->getSuccessor(0); + Iter = 
FallbackBlock->rbegin(); + CallInst *FCall = dyn_cast<CallInst>(&*(++Iter)); + // 'F' has a dummy DISubprogram which causes OutlinedFunc to also + // have a DISubprogram. In this case, the call to OutlinedFunc needs + // to have a debug loc, otherwise verifier will complain. + FCall->setDebugLoc(DL); + EXPECT_NE(FCall, nullptr); + + // Check that the outlined function exists with the expected prefix + Function *OutlinedFunc = FCall->getCalledFunction(); + EXPECT_NE(OutlinedFunc, nullptr); + StringRef FunctionName2 = OutlinedFunc->getName(); + EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading")); + + EXPECT_FALSE(verifyModule(*M, &errs())); +} + +TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.setConfig( + OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false, + /*OpenMPOffloadMandatory=*/false, + /*HasRequiresReverseOffload=*/false, + /*HasRequiresUnifiedAddress=*/false, + /*HasRequiresUnifiedSharedMemory=*/false, + /*HasRequiresDynamicAllocators=*/false)); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + + Function *OutlinedFn = nullptr; + llvm::SmallVector<llvm::Value *> CapturedArgs; + + auto SimpleArgAccessorCB = + [&](llvm::Argument &, llvm::Value *, llvm::Value *&, + llvm::OpenMPIRBuilder::InsertPointTy, + llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; + auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy) + -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; }; + + auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy, + OpenMPIRBuilder::InsertPointTy CodeGenIP) + -> OpenMPIRBuilder::InsertPointTy { + Builder.restoreIP(CodeGenIP); + OutlinedFn = CodeGenIP.getBlock()->getParent(); + return Builder.saveIP(); + }; + + IRBuilder<>::InsertPoint 
EntryIP(&F->getEntryBlock(), + F->getEntryBlock().getFirstInsertionPt()); + TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, + /*Line=*/3, /*Count=*/0); + + OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { + /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB); + assert(AfterIP && "unexpected error"); + Builder.restoreIP(*AfterIP); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + // Check outlined function + EXPECT_FALSE(verifyModule(*M, &errs())); + EXPECT_NE(OutlinedFn, nullptr); + EXPECT_NE(F, OutlinedFn); + + // Check that target-cpu and target-features were propagated to the outlined + // function + EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"), + F->getFnAttribute("target-cpu")); + EXPECT_EQ(OutlinedFn->getFnAttribute("target-features"), + F->getFnAttribute("target-features")); + + EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage()); + // Account for the "implicit" first argument. + EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3"); + EXPECT_EQ(OutlinedFn->arg_size(), 1U); + + // Check global exec_mode. 
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used"); + EXPECT_NE(Used, nullptr); + Constant *UsedInit = Used->getInitializer(); + EXPECT_NE(UsedInit, nullptr); + EXPECT_TRUE(isa<ConstantArray>(UsedInit)); + auto *UsedInitData = cast<ConstantArray>(UsedInit); + EXPECT_EQ(1U, UsedInitData->getNumOperands()); + Constant *ExecMode = UsedInitData->getOperand(0); + EXPECT_TRUE(isa<GlobalVariable>(ExecMode)); + Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer(); + EXPECT_NE(ExecModeValue, nullptr); + EXPECT_TRUE(isa<ConstantInt>(ExecModeValue)); + EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD, + cast<ConstantInt>(ExecModeValue)->getZExtValue()); } TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { @@ -6448,9 +6707,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs, - CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index d3c3839accb7e7..9bdf3e11496f3a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3936,9 +3936,11 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, allocaIP, codeGenIP); }; - // TODO: Populate default attributes based on the construct and clauses. 
+ // TODO: Populate default and runtime attributes based on the construct and + // clauses. llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs; llvm::SmallVector<llvm::Value *, 4> kernelInput; for (size_t i = 0; i < mapVars.size(); ++i) { @@ -3957,9 +3959,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( - ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, - defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds, - targetOp.getNowait()); + ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, builder.saveIP(), + entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, + bodyCB, argAccessorCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits