https://github.com/skatrak created 
https://github.com/llvm/llvm-project/pull/116051

This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host- 
evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values 
passed to the runtime kernel offloading call.

Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to 
influence target device code generation.

From cc5c5cc8b1c8b718ae3d0aece3784416460114bc Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safon...@amd.com>
Date: Fri, 8 Nov 2024 17:24:47 +0000
Subject: [PATCH] [OMPIRBuilder] Support runtime number of teams and threads,
 and SPMD mode

This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-
evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values
passed to the runtime kernel offloading call.

Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to
influence target device code generation.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  26 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 137 +++++++--
 .../Frontend/OpenMPIRBuilderTest.cpp          | 281 +++++++++++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  10 +-
 4 files changed, 420 insertions(+), 34 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
     int32_t MinThreads = 1;
   };
 
+  /// Container to pass LLVM IR runtime values or constants related to the
+  /// number of teams and threads with which the kernel must be launched, as
+  /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+  /// must be defined in the host prior to the call to the kernel launch OpenMP
+  /// RTL function.
+  struct TargetKernelRuntimeAttrs {
+    SmallVector<Value *, 3> MaxTeams = {nullptr};
+    Value *MinTeams = nullptr;
+    SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+    SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+    /// 'parallel' construct 'num_threads' clause value, if present and it is a
+    /// target SPMD kernel.
+    Value *MaxThreads = nullptr;
+
+    /// Total number of iterations of the target SPMD kernel or null if it is a
+    /// generic kernel.
+    Value *LoopTripCount = nullptr;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc where the target data construct was encountered.
   /// \param IsOffloadEntry whether it is an offload entry.
+  /// \param IsSPMD whether it is a target SPMD kernel.
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
   /// \param DefaultAttrs Structure containing the default numbers of threads
   ///        and teams to launch the kernel with.
+  /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
   // dependency information as passed in the depend clause
   // \param HasNowait Whether the target construct has a `nowait` clause or not.
   InsertPointOrErrorTy createTarget(
-      const LocationDescription &Loc, bool IsOffloadEntry,
+      const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..f847f60386df85 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
 }
 
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+                     Module &M) {
+  if (List.empty())
+    return;
+
+  Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+  // Convert List to what ConstantArray needs.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.reserve(List.size());
+  for (auto Item : List)
+    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*Item), PtrTy));
+
+  ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+  auto *GV =
+      new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+                         llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+                  StringRef FunctionName, OMPTgtExecModeFlags Mode,
+                  std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+  auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+  auto *GVMode = new llvm::GlobalVariable(
+      OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+      llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+      Twine(FunctionName, "_exec_mode"));
+  GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+  LLVMCompilerUsed.emplace_back(GVMode);
+}
+
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
   auto Func =
       Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
 
+  // Forward target-cpu and target-features function attributes from the
+  // original function to the new outlined function.
+  Function *ParentFn = Builder.GetInsertBlock()->getParent();
+
+  auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
+  if (TargetCpuAttr.isStringAttribute())
+    Func->addFnAttr(TargetCpuAttr);
+
+  auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
+  if (TargetFeaturesAttr.isStringAttribute())
+    Func->addFnAttr(TargetFeaturesAttr);
+
+  if (OMPBuilder.Config.isTargetDevice()) {
+    std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+    emitExecutionMode(OMPBuilder, Builder, FuncName,
+                      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+                             : OMP_TGT_EXEC_MODE_GENERIC,
+                      LLVMCompilerUsed);
+    emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+  }
+
   // Save insert point.
   IRBuilder<>::InsertPointGuard IPG(Builder);
   // If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
     Builder.restoreIP(
-        OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+        OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo,
+    bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     Function *&OutlinedFn, Constant *&OutlinedFnID,
     SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
       [&](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+        return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, 
DefaultAttrs,
                                       EntryFnName, Inputs, CBFunc,
                                       ArgAccessorFuncCB);
       };
@@ -7304,6 +7360,7 @@ static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
                Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                                          /*ForEndCall=*/false);
 
   SmallVector<Value *, 3> NumTeamsC;
+  for (auto [DefaultVal, RuntimeVal] :
+       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : 
Builder.getInt32(DefaultVal));
+
+  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+    if (Clause)
+      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                     /*isSigned=*/false);
+    return Clause;
+  };
+  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+    if (Clause)
+      Result = Result
+                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, 
Clause),
+                                          Result, Clause)
+                   : Clause;
+  };
+
+  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+  // the NUM_THREADS clause is overridden by THREAD_LIMIT.
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : DefaultAttrs.MaxTeams)
-    NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : DefaultAttrs.MaxThreads)
-    NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+                                : nullptr;
+
+  for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+           RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+  }
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
   // TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                              llvm::omp::IdentFlag(0), 0);
-  // TODO: Use correct NumIterations
-  Value *NumIterations = Builder.getInt64(0);
+
+  Value *TripCount = RuntimeAttrs.LoopTripCount
+                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                 Builder.getInt64Ty(),
+                                                 /*isSigned=*/false)
+                         : Builder.getInt64(0);
+
   // TODO: Use correct DynCGGroupMem
   Value *DynCGGroupMem = Builder.getInt32(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(
-      NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
-      DynCGGroupMem, HasNoWait);
+  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                            NumTeamsC, NumThreadsC,
+                                            DynCGGroupMem, HasNoWait);
 
   // The presence of certain clauses on the target directive require the
   // explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
-    const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy 
AllocaIP,
-    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+    const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+    InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+    TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
+  assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+         "trip count not expected if IsSPMD=false");
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
-          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+          OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
-                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+                   OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+                   HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..63be7e775b83c9 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6122,8 +6122,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   OpenMPIRBuilderConfig Config(false, false, false, false, false, false, 
false);
   OMPBuilder.setConfig(Config);
   F->setName("func");
+  F->addFnAttr("target-cpu", "x86-64");
+  F->addFnAttr("target-features", "+mmx,+sse");
   IRBuilder<> Builder(BB);
-  auto Int32Ty = Builder.getInt32Ty();
+  auto *Int32Ty = Builder.getInt32Ty();
 
   AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
   AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6185,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
-      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, 
Builder.saveIP(),
-                              Builder.saveIP(), EntryInfo, DefaultAttrs, 
Inputs,
-                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+  RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+  RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+      Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+      GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
   OMPBuilder.finalize();
@@ -6207,6 +6213,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName = KernelLaunchFunc->getName();
   EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
 
+  // Check num_teams and num_threads in call arguments
+  EXPECT_TRUE(Call->arg_size() >= 4);
+  Value *NumTeamsArg = Call->getArgOperand(2);
+  EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+  EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+  Value *NumThreadsArg = Call->getArgOperand(3);
+  EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+  EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+  // Check num_teams and num_threads kernel arguments (use number 5 starting
+  // from the end and counting the call to __tgt_target_kernel as the first use)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+  Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+  Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+  Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+  auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+  EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+  EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+  Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+  Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+  Value *NumThreadsStoreArg =
+      cast<StoreInst>(NumThreadsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+  auto *NumThreadsStoreValue = 
cast<ConstantDataSequential>(NumThreadsStoreArg);
+  EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+  EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
   // Check the fallback call
   BasicBlock *FallbackBlock = Branch->getSuccessor(0);
   Iter = FallbackBlock->rbegin();
@@ -6228,6 +6271,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName2 = OutlinedFunc->getName();
   EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
 
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFunc->getFnAttribute("target-cpu"),
+            F->getFnAttribute("target-cpu"));
+  EXPECT_EQ(OutlinedFunc->getFnAttribute("target-features"),
+            F->getFnAttribute("target-features"));
+
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
@@ -6238,6 +6288,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   OMPBuilder.initialize();
 
   F->setName("func");
+  F->addFnAttr("target-cpu", "gfx90a");
+  F->addFnAttr("target-features", "+gfx9-insts,+wavefrontsize64");
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
 
@@ -6297,9 +6349,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
 
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
-      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -6312,6 +6366,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   Function *OutlinedFn = TargetStore->getFunction();
   EXPECT_NE(F, OutlinedFn);
 
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"),
+            F->getFnAttribute("target-cpu"));
+  EXPECT_EQ(OutlinedFn->getFnAttribute("target-features"),
+            F->getFnAttribute("target-features"));
+
   EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
   // Account for the "implicit" first argument.
   EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
@@ -6378,6 +6439,204 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
   EXPECT_EQ(ExitBlock->getName(), "worker.exit");
   EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
+
+  // Check global exec_mode.
+  GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+  EXPECT_NE(Used, nullptr);
+  Constant *UsedInit = Used->getInitializer();
+  EXPECT_NE(UsedInit, nullptr);
+  EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+  auto *UsedInitData = cast<ConstantArray>(UsedInit);
+  EXPECT_EQ(1U, UsedInitData->getNumOperands());
+  Constant *ExecMode = UsedInitData->getOperand(0);
+  EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+  Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+  EXPECT_NE(ExecModeValue, nullptr);
+  EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+  EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC,
+            cast<ConstantInt>(ExecModeValue)->getZExtValue());
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false,
+                               /*OpenMPOffloadMandatory=*/false,
+                               /*HasRequiresReverseOffload=*/false,
+                               /*HasRequiresUnifiedAddress=*/false,
+                               /*HasRequiresUnifiedSharedMemory=*/false,
+                               /*HasRequiresDynamicAllocators=*/false);
+  OMPBuilder.setConfig(Config);
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  auto BodyGenCB = [&](InsertPointTy,
+                       InsertPointTy CodeGenIP) -> InsertPointTy {
+    Builder.restoreIP(CodeGenIP);
+    return Builder.saveIP();
+  };
+
+  auto SimpleArgAccessorCB =
+      [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
+          llvm::OpenMPIRBuilder::InsertPointTy,
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+        Builder.restoreIP(CodeGenIP);
+        return Builder.saveIP();
+      };
+
+  llvm::SmallVector<llvm::Value *> Inputs;
+  llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
+      -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+
+  TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
+  OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(),
+      Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+      GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+  assert(AfterIP && "unexpected error");
+  Builder.restoreIP(*AfterIP);
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // Check the kernel launch sequence
+  auto Iter = F->getEntryBlock().rbegin();
+  EXPECT_TRUE(isa<BranchInst>(&*(Iter)));
+  BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter));
+  EXPECT_TRUE(isa<CmpInst>(&*(++Iter)));
+  EXPECT_TRUE(isa<CallInst>(&*(++Iter)));
+  CallInst *Call = dyn_cast<CallInst>(&*(Iter));
+
+  // Check that the kernel launch function is called
+  Function *KernelLaunchFunc = Call->getCalledFunction();
+  EXPECT_NE(KernelLaunchFunc, nullptr);
+  StringRef FunctionName = KernelLaunchFunc->getName();
+  EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+
+  // Check the trip count kernel argument (use number 5 starting from the end
+  // and counting the call to __tgt_target_kernel as the first use)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 6);
+  Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5);
+  EXPECT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr));
+  Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(TripCountStore));
+  Value *TripCountStoreArg = 
cast<StoreInst>(TripCountStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantInt>(TripCountStoreArg));
+  EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue());
+
+  // Check the fallback call
+  BasicBlock *FallbackBlock = Branch->getSuccessor(0);
+  Iter = FallbackBlock->rbegin();
+  CallInst *FCall = dyn_cast<CallInst>(&*(++Iter));
+  // 'F' has a dummy DISubprogram which causes OutlinedFunc to also
+  // have a DISubprogram. In this case, the call to OutlinedFunc needs
+  // to have a debug loc, otherwise verifier will complain.
+  FCall->setDebugLoc(DL);
+  EXPECT_NE(FCall, nullptr);
+
+  // Check that the outlined function exists with the expected prefix
+  Function *OutlinedFunc = FCall->getCalledFunction();
+  EXPECT_NE(OutlinedFunc, nullptr);
+  StringRef FunctionName2 = OutlinedFunc->getName();
+  EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.setConfig(
+      OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false,
+                            /*OpenMPOffloadMandatory=*/false,
+                            /*HasRequiresReverseOffload=*/false,
+                            /*HasRequiresUnifiedAddress=*/false,
+                            /*HasRequiresUnifiedSharedMemory=*/false,
+                            /*HasRequiresDynamicAllocators=*/false));
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  Function *OutlinedFn = nullptr;
+  llvm::SmallVector<llvm::Value *> CapturedArgs;
+
+  auto SimpleArgAccessorCB =
+      [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
+          llvm::OpenMPIRBuilder::InsertPointTy,
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+        Builder.restoreIP(CodeGenIP);
+        return Builder.saveIP();
+      };
+
+  llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
+      -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+
+  auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
+                       OpenMPIRBuilder::InsertPointTy CodeGenIP)
+      -> OpenMPIRBuilder::InsertPointTy {
+    Builder.restoreIP(CodeGenIP);
+    OutlinedFn = CodeGenIP.getBlock()->getParent();
+    return Builder.saveIP();
+  };
+
+  IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
+                                   F->getEntryBlock().getFirstInsertionPt());
+  TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
+                                  /*Line=*/3, /*Count=*/0);
+
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
+  assert(AfterIP && "unexpected error");
+  Builder.restoreIP(*AfterIP);
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  // Check outlined function
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  EXPECT_NE(OutlinedFn, nullptr);
+  EXPECT_NE(F, OutlinedFn);
+
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"),
+            F->getFnAttribute("target-cpu"));
+  EXPECT_EQ(OutlinedFn->getFnAttribute("target-features"),
+            F->getFnAttribute("target-features"));
+
+  EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
+  // Account for the "implicit" first argument.
+  EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
+  EXPECT_EQ(OutlinedFn->arg_size(), 1U);
+
+  // Check global exec_mode.
+  GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+  EXPECT_NE(Used, nullptr);
+  Constant *UsedInit = Used->getInitializer();
+  EXPECT_NE(UsedInit, nullptr);
+  EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+  auto *UsedInitData = cast<ConstantArray>(UsedInit);
+  EXPECT_EQ(1U, UsedInitData->getNumOperands());
+  Constant *ExecMode = UsedInitData->getOperand(0);
+  EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+  Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+  EXPECT_NE(ExecModeValue, nullptr);
+  EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+  EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD,
+            cast<ConstantInt>(ExecModeValue)->getZExtValue());
 }
 
 TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
@@ -6448,9 +6707,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
 
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
-      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d3c3839accb7e7..9bdf3e11496f3a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3936,9 +3936,11 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                                         allocaIP, codeGenIP);
   };
 
-  // TODO: Populate default attributes based on the construct and clauses.
+  // TODO: Populate default and runtime attributes based on the construct and
+  // clauses.
   llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
 
   llvm::SmallVector<llvm::Value *, 4> kernelInput;
   for (size_t i = 0; i < mapVars.size(); ++i) {
@@ -3957,9 +3959,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
-          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
-          defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
-          targetOp.getNowait());
+          ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, builder.saveIP(),
+          entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB,
+          bodyCB, argAccessorCB, dds, targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
     return failure();

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to