https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/161864

>From 33e2524e060ecbdece8f91a3186d46088aa24140 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <[email protected]>
Date: Tue, 16 Sep 2025 14:18:39 +0100
Subject: [PATCH] [MLIR][OpenMP][OMPIRBuilder] Improve shared memory checks

This patch refines the checks that decide whether to use device shared memory or
regular stack allocations. In particular, it adds support for parallel regions
residing in standalone target device functions.

The changes are:
- Shared memory is introduced for `omp.target` implicit allocations, such as
those related to privatization and mapping, as long as they are shared across
threads in a nested parallel region.
- Standalone target device functions are interpreted as being part of a Generic
kernel, since their presence in the module after filtering means they must be
reachable from a target region.
- Allocations whose only shared uses inside an `omp.parallel` region are as
part of a `private` clause are no longer moved to device shared memory (see the
sketch below).
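
As a rough sketch of the first and last points (hand-written MLIR, not part of
the patch; the map producers and the @x.privatizer symbol are invented for
illustration), the implicit target-level allocation backing %shared is read by
threads of the nested parallel region and is therefore a candidate for device
shared memory, while %privatized is only consumed through a `private` clause,
so it can remain a regular stack allocation:

  // %m1 and %m2 are results of omp.map.info ops (elided).
  omp.target map_entries(%m1 -> %shared, %m2 -> %privatized : !llvm.ptr, !llvm.ptr) {
    // Every thread gets its own copy of %privatized, so the original
    // allocation is never shared across threads.
    omp.parallel private(@x.privatizer %privatized -> %arg0 : !llvm.ptr) {
      // Direct cross-thread read of a target-level value: this use is what
      // makes %shared require a device shared memory allocation.
      %0 = llvm.load %shared : !llvm.ptr -> i32
      omp.terminator
    }
    omp.terminator
  }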
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   4 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  28 ++--
 .../Frontend/OpenMPIRBuilderTest.cpp          |  55 +++++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 122 ++++++++++++------
 .../LLVMIR/omptarget-parallel-llvm.mlir       |   8 +-
 .../LLVMIR/omptarget-parallel-wsloop.mlir     |   7 +-
 .../fortran/target-generic-outlined-loops.f90 | 109 ++++++++++++++++
 7 files changed, 258 insertions(+), 75 deletions(-)
 create mode 100644 offload/test/offloading/fortran/target-generic-outlined-loops.f90

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index d8e5f8cf5a45e..410912ba375a3 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3292,8 +3292,8 @@ class OpenMPIRBuilder {
       ArrayRef<InsertPointTy> DeallocIPs)>;
 
   using TargetGenArgAccessorsCallbackTy = function_ref<InsertPointOrErrorTy(
-      Argument &Arg, Value *Input, Value *&RetVal, InsertPointTy AllocaIP,
-      InsertPointTy CodeGenIP)>;
+      Argument &Arg, Value *Input, Value *&RetVal, InsertPointTy AllocIP,
+      InsertPointTy CodeGenIP, ArrayRef<InsertPointTy> DeallocIPs)>;
 
   /// Generator for '#omp target'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a18db939b5876..c164d32f8f98c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -312,6 +312,12 @@ getTargetKernelExecMode(Function &Kernel) {
   return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
 }
 
+static bool isGenericKernel(Function &Fn) {
+  std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+      getTargetKernelExecMode(Fn);
+  return !ExecMode || (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC);
+}
+
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
@@ -1535,11 +1541,9 @@ static void targetParallelCallback(
       IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                   : Builder.getInt32(1);
 
-  // If this is not a Generic kernel, we can skip generating the wrapper.
-  std::optional<omp::OMPTgtExecModeFlags> ExecMode =
-      getTargetKernelExecMode(*OuterFn);
+  // If this is a Generic kernel, we need to generate the wrapper.
   Value *WrapperFn;
-  if (ExecMode && (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC))
+  if (isGenericKernel(*OuterFn))
     WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
   else
     WrapperFn = Constant::getNullValue(PtrTy);
@@ -1812,13 +1816,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   auto OI = [&]() -> std::unique_ptr<OutlineInfo> {
     if (Config.isTargetDevice()) {
-      std::optional<omp::OMPTgtExecModeFlags> ExecMode =
-          getTargetKernelExecMode(*OuterFn);
-
-      // If OuterFn is not a Generic kernel, skip custom allocation. This causes
-      // the CodeExtractor to follow its default behavior. Otherwise, we need to
-      // use device shared memory to allocate argument structures.
-      if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
+      // If OuterFn is a Generic kernel, we need to use device shared memory to
+      // allocate argument structures. Otherwise, we use stack allocations as
+      // usual.
+      if (isGenericKernel(*OuterFn))
         return std::make_unique<DeviceSharedMemOutlineInfo>(*this);
     }
     return std::make_unique<OutlineInfo>();
@@ -7806,8 +7807,9 @@ static Expected<Function *> createOutlinedFunction(
     Argument &Arg = std::get<1>(InArg);
     Value *InputCopy = nullptr;
 
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = ArgAccessorFuncCB(
+        Arg, Input, InputCopy, AllocaIP, Builder.saveIP(),
+        OpenMPIRBuilder::InsertPointTy(ExitBB, ExitBB->begin()));
     if (!AfterIP)
       return AfterIP.takeError();
     Builder.restoreIP(*AfterIP);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 1e5b8145d5cdc..d231a778a8a97 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -745,8 +745,10 @@ TEST_F(OpenMPIRBuilderTest, ParallelSimpleGPU) {
   EXPECT_EQ(OutlinedFn->getArg(2)->getType(),
             PointerType::get(M->getContext(), 0));
   EXPECT_EQ(&OutlinedFn->getEntryBlock(), PrivAI->getParent());
-  EXPECT_TRUE(OutlinedFn->hasOneUse());
-  User *Usr = OutlinedFn->user_back();
+  EXPECT_TRUE(OutlinedFn->hasNUses(2));
+  User *Usr = *OutlinedFn->users().begin();
+  User *WrapperUsr = *++OutlinedFn->users().begin();
+
   ASSERT_TRUE(isa<CallInst>(Usr));
   CallInst *Parallel51CI = dyn_cast<CallInst>(Usr);
   ASSERT_NE(Parallel51CI, nullptr);
@@ -757,6 +759,20 @@ TEST_F(OpenMPIRBuilderTest, ParallelSimpleGPU) {
   EXPECT_TRUE(
       isa<GlobalVariable>(Parallel51CI->getArgOperand(0)->stripPointerCasts()));
   EXPECT_EQ(Parallel51CI, Usr);
+
+  ASSERT_TRUE(isa<CallInst>(WrapperUsr));
+  CallInst *OutlinedCI = dyn_cast<CallInst>(WrapperUsr);
+  ASSERT_NE(OutlinedCI, nullptr);
+  EXPECT_EQ(OutlinedCI->getCalledFunction(), OutlinedFn);
+
+  Function *WrapperFn = OutlinedCI->getFunction();
+  EXPECT_TRUE(WrapperFn->hasInternalLinkage());
+  EXPECT_EQ(WrapperFn->arg_size(), 2U);
+  EXPECT_EQ(WrapperFn->getArg(0)->getType(),
+            IntegerType::getInt16Ty(M->getContext()));
+  EXPECT_EQ(WrapperFn->getArg(1)->getType(),
+            IntegerType::getInt32Ty(M->getContext()));
+
   M->setDataLayout(oldDLStr);
 }
 
@@ -6396,7 +6412,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   auto SimpleArgAccessorCB =
       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
         IRBuilderBase::InsertPointGuard guard(Builder);
         Builder.SetCurrentDebugLocation(llvm::DebugLoc());
         if (!OMPBuilder.Config.isTargetDevice()) {
@@ -6561,7 +6578,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   auto SimpleArgAccessorCB =
       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
         IRBuilderBase::InsertPointGuard guard(Builder);
         Builder.SetCurrentDebugLocation(llvm::DebugLoc());
         if (!OMPBuilder.Config.isTargetDevice()) {
@@ -6763,12 +6781,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
     return Builder.saveIP();
   };
 
-  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
-                                 OpenMPIRBuilder::InsertPointTy,
-                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
-    Builder.restoreIP(CodeGenIP);
-    return Builder.saveIP();
-  };
+  auto SimpleArgAccessorCB =
+      [&](Argument &, Value *, Value *&, OpenMPIRBuilder::InsertPointTy,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy>) {
+        Builder.restoreIP(CodeGenIP);
+        return Builder.saveIP();
+      };
 
   SmallVector<Value *> Inputs;
   OpenMPIRBuilder::MapInfosTy CombinedInfos;
@@ -6862,12 +6881,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
   Function *OutlinedFn = nullptr;
   SmallVector<Value *> CapturedArgs;
 
-  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
-                                 OpenMPIRBuilder::InsertPointTy,
-                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
-    Builder.restoreIP(CodeGenIP);
-    return Builder.saveIP();
-  };
+  auto SimpleArgAccessorCB =
+      [&](Argument &, Value *, Value *&, OpenMPIRBuilder::InsertPointTy,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy>) {
+        Builder.restoreIP(CodeGenIP);
+        return Builder.saveIP();
+      };
 
   OpenMPIRBuilder::MapInfosTy CombinedInfos;
   auto GenMapInfoCB =
@@ -6961,7 +6981,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
   auto SimpleArgAccessorCB =
       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
         IRBuilderBase::InsertPointGuard guard(Builder);
         Builder.SetCurrentDebugLocation(llvm::DebugLoc());
         if (!OMPBuilder.Config.isTargetDevice()) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3accca891ba9c..1183865dc3645 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1128,9 +1128,10 @@ struct DeferredStore {
 } // namespace
 
 /// Check whether allocations for the given operation might potentially have to
-/// be done in device shared memory. That means we're compiling for a offloading
-/// target, the operation is an `omp::TargetOp` or nested inside of one and that
-/// target region represents a Generic (non-SPMD) kernel.
+/// be done in device shared memory. That means we're compiling for an
+/// offloading target, the operation is neither an `omp::TargetOp` nor nested
+/// inside of one, or it is and that target region represents a Generic
+/// (non-SPMD) kernel.
 ///
 /// This represents a necessary but not sufficient set of conditions to use
 /// device shared memory in place of regular allocas. For some variables, the
@@ -1146,7 +1147,7 @@ mightAllocInDeviceSharedMemory(Operation &op,
   if (!targetOp)
     targetOp = op.getParentOfType<omp::TargetOp>();
 
-  return targetOp &&
+  return !targetOp ||
          targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) ==
              omp::TargetExecMode::generic;
 }
@@ -1160,18 +1161,36 @@ mightAllocInDeviceSharedMemory(Operation &op,
 /// operation that owns the specified block argument.
 static bool mustAllocPrivateVarInDeviceSharedMemory(BlockArgument value) {
   Operation *parentOp = value.getOwner()->getParentOp();
-  auto targetOp = dyn_cast<omp::TargetOp>(parentOp);
-  if (!targetOp)
-    targetOp = parentOp->getParentOfType<omp::TargetOp>();
-  assert(targetOp && "expected a parent omp.target operation");
-
+  auto moduleOp = parentOp->getParentOfType<ModuleOp>();
   for (auto *user : value.getUsers()) {
     if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) {
       if (llvm::is_contained(parallelOp.getReductionVars(), value))
         return true;
     } else if (auto parallelOp = user->getParentOfType<omp::ParallelOp>()) {
-      if (parentOp->isProperAncestor(parallelOp))
-        return true;
+      if (parentOp->isProperAncestor(parallelOp)) {
+        // If it is used directly inside of a parallel region, skip private
+        // clause uses.
+        bool isPrivateClauseUse = false;
+        if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(user)) {
+          if (auto privateSyms = llvm::cast_or_null<ArrayAttr>(
+                  user->getAttr("private_syms"))) {
+            for (auto [var, sym] :
+                 llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) {
+              if (var != value)
+                continue;
+
+              auto privateOp = cast<omp::PrivateClauseOp>(
+                  moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym)));
+              if (privateOp.getCopyRegion().empty()) {
+                isPrivateClauseUse = true;
+                break;
+              }
+            }
+          }
+        }
+        if (!isPrivateClauseUse)
+          return true;
+      }
     }
   }
 
@@ -1196,8 +1215,8 @@ allocReductionVars(T op, ArrayRef<BlockArgument> reductionArgs,
   builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  bool useDeviceSharedMem =
-      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+  bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
+                            mightAllocInDeviceSharedMemory(*op, *ompBuilder);
 
   // delay creating stores until after all allocas
   deferredStores.reserve(op.getNumReductionVars());
@@ -1318,8 +1337,8 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
     return success();
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  bool useDeviceSharedMem =
-      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+  bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
+                            mightAllocInDeviceSharedMemory(*op, *ompBuilder);
 
   llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
   auto allocaIP = llvm::IRBuilderBase::InsertPoint(
@@ -1540,8 +1559,8 @@ static LogicalResult createReductionsAndCleanup(
       reductionRegions, privateReductionVariables, moduleTranslation, builder,
       "omp.reduction.cleanup");
 
-  bool useDeviceSharedMem =
-      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+  bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
+                            mightAllocInDeviceSharedMemory(*op, *ompBuilder);
   if (useDeviceSharedMem) {
     for (auto [var, reductionDecl] :
          llvm::zip_equal(privateReductionVariables, reductionDecls))
@@ -1721,7 +1740,7 @@ allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   bool mightUseDeviceSharedMem =
-      isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
+      isa<omp::TargetOp, omp::TeamsOp, omp::DistributeOp>(*op) &&
       mightAllocInDeviceSharedMemory(*op, *ompBuilder);
   unsigned int allocaAS =
       moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
@@ -1839,7 +1858,7 @@ cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   bool mightUseDeviceSharedMem =
-      isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
+      isa<omp::TargetOp, omp::TeamsOp, omp::DistributeOp>(*op) &&
       mightAllocInDeviceSharedMemory(*op, *ompBuilder);
   for (auto [privDecl, llvmPrivVar, blockArg] :
        llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars,
@@ -5265,42 +5284,68 @@ handleDeclareTargetMapVar(MapInfoData &mapData,
 // a store of the kernel argument into this allocated memory which
 // will then be loaded from, ByCopy will use the allocated memory
 // directly.
-static llvm::IRBuilderBase::InsertPoint
-createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
-                             llvm::Value *input, llvm::Value *&retVal,
-                             llvm::IRBuilderBase &builder,
-                             llvm::OpenMPIRBuilder &ompBuilder,
-                             LLVM::ModuleTranslation &moduleTranslation,
-                             llvm::IRBuilderBase::InsertPoint allocaIP,
-                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
+static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(
+    omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
+    llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
+    llvm::OpenMPIRBuilder &ompBuilder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::IRBuilderBase::InsertPoint allocIP,
+    llvm::IRBuilderBase::InsertPoint codeGenIP,
+    llvm::ArrayRef<llvm::IRBuilderBase::InsertPoint> deallocIPs) {
   assert(ompBuilder.Config.isTargetDevice() &&
          "function only supported for target device codegen");
-  builder.restoreIP(allocaIP);
+  builder.restoreIP(allocIP);
 
   omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
   LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
       ompBuilder.M.getContext());
   unsigned alignmentValue = 0;
+  BlockArgument mlirArg;
   // Find the associated MapInfoData entry for the current input
-  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
+  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
     if (mapData.OriginalValue[i] == input) {
       auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
       capture = mapOp.getMapCaptureType();
       // Get information of alignment of mapped object
       alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
           mapOp.getVarType(), ompBuilder.M.getDataLayout());
+      // Get the corresponding target entry block argument
+      mlirArg =
+          cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getMapBlockArgs()[i];
       break;
     }
+  }
 
   unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
   unsigned int defaultAS =
       ompBuilder.M.getDataLayout().getProgramAddressSpace();
 
-  // Create the alloca for the argument the current point.
-  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
+  // Create the allocation for the argument.
+  llvm::Value *v = nullptr;
+  if (mightAllocInDeviceSharedMemory(*targetOp, ompBuilder) &&
+      mustAllocPrivateVarInDeviceSharedMemory(mlirArg)) {
+    // Use the beginning of the codeGenIP rather than the usual allocation point
+    // for shared memory allocations because otherwise these would be done prior
+    // to the target initialization call. Also, the exit block (where the
+    // deallocation is placed) is only executed if the initialization call
+    // succeeds.
+    builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
+    v = ompBuilder.createOMPAllocShared(builder, arg.getType());
+
+    // Create deallocations in all provided deallocation points and then restore
+    // the insertion point to right after the new allocations.
+    llvm::IRBuilderBase::InsertPointGuard guard(builder);
+    for (auto deallocIP : deallocIPs) {
+      builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
+      ompBuilder.createOMPFreeShared(builder, v, arg.getType());
+    }
+  } else {
+    // Use the current point, which was previously set to allocIP.
+    v = builder.CreateAlloca(arg.getType(), allocaAS);
 
-  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
-    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
+    if (allocaAS != defaultAS && arg.getType()->isPointerTy())
+      v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
+  }
 
   builder.CreateStore(&arg, v);
 
@@ -5890,8 +5935,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   };
 
   auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
-                           llvm::Value *&retVal, InsertPointTy allocaIP,
-                           InsertPointTy codeGenIP)
+                           llvm::Value *&retVal, InsertPointTy allocIP,
+                           InsertPointTy codeGenIP,
+                           llvm::ArrayRef<InsertPointTy> deallocIPs)
       -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
     llvm::IRBuilderBase::InsertPointGuard guard(builder);
     builder.SetCurrentDebugLocation(llvm::DebugLoc());
@@ -5905,9 +5951,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       return codeGenIP;
     }
 
-    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
-                                        *ompBuilder, moduleTranslation,
-                                        allocaIP, codeGenIP);
+    return createDeviceArgumentAccessor(targetOp, mapData, arg, input, retVal,
+                                        builder, *ompBuilder, moduleTranslation,
+                                        allocIP, codeGenIP, deallocIPs);
   };
 
   llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index c3ce2f62c486e..c1a7c9736910a 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -55,15 +55,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK: define weak_odr protected amdgpu_kernel void @[[FUNC0:.*]](
 // CHECK-SAME: ptr %[[TMP:.*]], ptr %[[TMP0:.*]]) #{{[0-9]+}} {
 // CHECK:         %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
-// CHECK:         %[[TMP2:.*]] = alloca ptr, align 8, addrspace(5)
-// CHECK:         %[[TMP3:.*]] = addrspacecast ptr addrspace(5) %[[TMP2]] to ptr
-// CHECK:         store ptr %[[TMP0]], ptr %[[TMP3]], align 8
// CHECK:         %[[TMP4:.*]] = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @{{.*}} to ptr), ptr %[[TMP]])
// CHECK:         %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP4]], -1
// CHECK:         br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]]
// CHECK:         %[[TMP5:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr
// CHECK:         %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
-// CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP3]], align 8
+// CHECK:         %[[TMP2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
+// CHECK:         store ptr %[[TMP0]], ptr %[[TMP2]], align 8
+// CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP2]], align 8
// CHECK:         %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr))
// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
// CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
@@ -71,6 +70,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
// CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP5]], i64 1)
 // CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
+// CHECK:         call void @__kmpc_free_shared(ptr %[[TMP2]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
 // CHECK: define internal void @[[FUNC1]](
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
index 5d2861a5d0f35..a8e770ec1d49b 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
@@ -29,7 +29,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:      call void @__kmpc_parallel_51(ptr addrspacecast
 // CHECK-SAME:  (ptr addrspace(1) @[[GLOB:[0-9]+]] to ptr),
 // CHECK-SAME:  i32 %[[THREAD_NUM:.*]], i32 1, i32 -1, i32 -1,
-// CHECK-SAME:  ptr @[[PARALLEL_FUNC:.*]], ptr null, ptr %[[PARALLEL_ARGS:.*]], i64 1)
+// CHECK-SAME:  ptr @[[PARALLEL_FUNC:.*]], ptr @[[PARALLEL_WRAPPER:.*]], ptr %[[PARALLEL_ARGS:.*]], i64 1)
 
 // CHECK:      define internal void @[[PARALLEL_FUNC]]
// CHECK-SAME:  (ptr noalias noundef %[[TID_ADDR:.*]], ptr noalias noundef %[[ZERO_ADDR:.*]],
@@ -41,6 +41,11 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 
// CHECK:      define internal void @[[LOOP_BODY_FUNC]](i32 %[[CNT:.*]], ptr %[[LOOP_BODY_ARG_PTR:.*]]) #[[ATTRS2:[0-9]+]] {
 
+// CHECK:      define internal void @[[PARALLEL_WRAPPER]](i16 {{.*}}, i32 {{.*}}) {
+// CHECK-NOT:    ret {{.*}}
+// CHECK:        call void @[[PARALLEL_FUNC]]({{.*}})
+// CHECK-NEXT:   ret void
+
 // CHECK:      attributes #[[ATTRS2]] = {
 // CHECK-SAME:  "target-cpu"="gfx90a"
 // CHECK-SAME:  "target-features"="+gfx9-insts,+wavefrontsize64"
diff --git a/offload/test/offloading/fortran/target-generic-outlined-loops.f90 b/offload/test/offloading/fortran/target-generic-outlined-loops.f90
new file mode 100644
index 0000000000000..594809027e115
--- /dev/null
+++ b/offload/test/offloading/fortran/target-generic-outlined-loops.f90
@@ -0,0 +1,109 @@
+! Offloading test for generic target regions containing different kinds of
+! loop constructs inside, moving parallel regions into a separate subroutine.
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-run-and-check-generic
+subroutine parallel_loop(n, counter)
+  implicit none
+  integer, intent(in) :: n
+  integer, intent(inout) :: counter
+  integer :: i
+
+  !$omp parallel do reduction(+:counter)
+  do i=1, n
+    counter = counter + 1
+  end do
+end subroutine
+
+program main
+  integer :: i1, i2, n1, n2, counter
+
+  n1 = 100
+  n2 = 50
+
+  counter = 0
+  !$omp target map(tofrom:counter)
+    !$omp teams distribute reduction(+:counter)
+    do i1=1, n1
+      counter = counter + 1
+    end do
+  !$omp end target
+
+  ! CHECK: 1 100
+  print '(I2" "I0)', 1, counter
+
+  counter = 0
+  !$omp target map(tofrom:counter)
+    call parallel_loop(n1, counter)
+    call parallel_loop(n1, counter)
+  !$omp end target
+
+  ! CHECK: 2 200
+  print '(I2" "I0)', 2, counter
+
+  counter = 0
+  !$omp target map(tofrom:counter)
+    counter = counter + 1
+    call parallel_loop(n1, counter)
+    counter = counter + 1
+    call parallel_loop(n1, counter)
+    counter = counter + 1
+  !$omp end target
+
+  ! CHECK: 3 203
+  print '(I2" "I0)', 3, counter
+
+  counter = 0
+  !$omp target map(tofrom: counter)
+    counter = counter + 1
+    call parallel_loop(n1, counter)
+    counter = counter + 1
+  !$omp end target
+
+  ! CHECK: 4 102
+  print '(I2" "I0)', 4, counter
+
+
+  counter = 0
+  !$omp target teams distribute reduction(+:counter)
+  do i1=1, n1
+    call parallel_loop(n2, counter)
+  end do
+
+  ! CHECK: 5 5000
+  print '(I2" "I0)', 5, counter
+
+  counter = 0
+  !$omp target teams distribute reduction(+:counter)
+  do i1=1, n1
+    counter = counter + 1
+    call parallel_loop(n2, counter)
+    counter = counter + 1
+  end do
+
+  ! CHECK: 6 5200
+  print '(I2" "I0)', 6, counter
+
+  counter = 0
+  !$omp target teams distribute reduction(+:counter)
+  do i1=1, n1
+    call parallel_loop(n2, counter)
+    call parallel_loop(n2, counter)
+  end do
+
+  ! CHECK: 7 10000
+  print '(I2" "I0)', 7, counter
+
+  counter = 0
+  !$omp target teams distribute reduction(+:counter)
+  do i1=1, n1
+    counter = counter + 1
+    call parallel_loop(n2, counter)
+    counter = counter + 1
+    call parallel_loop(n2, counter)
+    counter = counter + 1
+  end do
+
+  ! CHECK: 8 10300
+  print '(I2" "I0)', 8, counter
+end program
