llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Kevin Sala Penades (kevinsala)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/152830.diff


3 Files Affected:

- (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+26-14) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+5-2) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+8-2) 


``````````diff
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index a5f2f0efa2c3b..d9121827a813a 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -9489,18 +9489,30 @@ static llvm::Value *emitDeviceID(
   return DeviceID;
 }
 
-static llvm::Value *emitDynCGGroupMem(const OMPExecutableDirective &D,
-                                      CodeGenFunction &CGF) {
-  llvm::Value *DynCGroupMem = CGF.Builder.getInt32(0);
-
-  if (auto *DynMemClause = D.getSingleClause<OMPXDynCGroupMemClause>()) {
-    CodeGenFunction::RunCleanupsScope DynCGroupMemScope(CGF);
-    llvm::Value *DynCGroupMemVal = CGF.EmitScalarExpr(
-        DynMemClause->getSize(), /*IgnoreResultAssign=*/true);
-    DynCGroupMem = CGF.Builder.CreateIntCast(DynCGroupMemVal, CGF.Int32Ty,
-                                             /*isSigned=*/false);
-  }
-  return DynCGroupMem;
+static std::pair<llvm::Value *, bool>
+emitDynCGroupMem(const OMPExecutableDirective &D, CodeGenFunction &CGF) {
+  llvm::Value *DynGP = CGF.Builder.getInt32(0);
+  bool DynGPFallback = false;
+
+  if (auto *DynGPClause = D.getSingleClause<OMPDynGroupprivateClause>()) {
+    CodeGenFunction::RunCleanupsScope DynGPScope(CGF);
+    llvm::Value *DynGPVal =
+        CGF.EmitScalarExpr(DynGPClause->getSize(), /*IgnoreResultAssign=*/true);
+    DynGP = CGF.Builder.CreateIntCast(DynGPVal, CGF.Int32Ty,
+                                      /*isSigned=*/false);
+    DynGPFallback = (DynGPClause->getFirstDynGroupprivateModifier() !=
+                         OMPC_DYN_GROUPPRIVATE_strict &&
+                     DynGPClause->getSecondDynGroupprivateModifier() !=
+                         OMPC_DYN_GROUPPRIVATE_strict);
+  } else if (auto *OMPXDynCGClause =
+                 D.getSingleClause<OMPXDynCGroupMemClause>()) {
+    CodeGenFunction::RunCleanupsScope DynCGMemScope(CGF);
+    llvm::Value *DynCGMemVal = CGF.EmitScalarExpr(OMPXDynCGClause->getSize(),
+                                                  /*IgnoreResultAssign=*/true);
+    DynGP = CGF.Builder.CreateIntCast(DynCGMemVal, CGF.Int32Ty,
+                                      /*isSigned=*/false);
+  }
+  return {DynGP, DynGPFallback};
 }
 static void genMapInfoForCaptures(
     MappableExprsHandler &MEHandler, CodeGenFunction &CGF,
@@ -9710,7 +9722,7 @@ static void emitTargetCallKernelLaunch(
     llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());
     llvm::Value *NumIterations =
         OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter);
-    llvm::Value *DynCGGroupMem = emitDynCGGroupMem(D, CGF);
+    auto [DynCGroupMem, DynCGroupMemFallback] = emitDynCGroupMem(D, CGF);
     llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
         CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());
 
@@ -9720,7 +9732,7 @@ static void emitTargetCallKernelLaunch(
 
     llvm::OpenMPIRBuilder::TargetKernelArgs Args(
         NumTargetItems, RTArgs, NumIterations, NumTeams, NumThreads,
-        DynCGGroupMem, HasNoWait);
+        DynCGroupMem, HasNoWait, DynCGroupMemFallback);
 
     llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
         cantFail(OMPRuntime->getOMPBuilder().emitKernelLaunch(
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 19a4058b64382..ebc50eecb551e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2341,17 +2341,20 @@ class OpenMPIRBuilder {
     Value *DynCGGroupMem = nullptr;
     /// True if the kernel has 'no wait' clause.
     bool HasNoWait = false;
+    /// True if the dynamic shared memory may fallback.
+    bool MayFallbackDynCGroupMem = false;
 
     // Constructors for TargetKernelArgs.
     TargetKernelArgs() {}
     TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
                      Value *NumIterations, ArrayRef<Value *> NumTeams,
                      ArrayRef<Value *> NumThreads, Value *DynCGGroupMem,
-                     bool HasNoWait)
+                     bool HasNoWait, bool MayFallbackDynCGroupMem)
         : NumTargetItems(NumTargetItems), RTArgs(RTArgs),
           NumIterations(NumIterations), NumTeams(NumTeams),
           NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
-          HasNoWait(HasNoWait) {}
+          HasNoWait(HasNoWait),
+          MayFallbackDynCGroupMem(MayFallbackDynCGroupMem) {}
   };
 
   /// Create the kernel args vector used by emitTargetKernel. This function
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 170224616ac64..e600508d347cb 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -506,7 +506,13 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
   auto Int32Ty = Type::getInt32Ty(Builder.getContext());
   constexpr const size_t MaxDim = 3;
   Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));
-  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
+
+  Value *HasNoWaitFlag = Builder.getInt64(KernelArgs.HasNoWait);
+  Value *MayFallbackDynCGroupMemFlag =
+      Builder.getInt64(KernelArgs.MayFallbackDynCGroupMem);
+  MayFallbackDynCGroupMemFlag =
+      Builder.CreateShl(MayFallbackDynCGroupMemFlag, 2);
+  Value *Flags = Builder.CreateOr(HasNoWaitFlag, MayFallbackDynCGroupMemFlag);
 
   assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());
 
@@ -7891,7 +7897,7 @@ static void emitTargetCall(
 
     KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
                                               NumTeamsC, NumThreadsC,
-                                              DynCGGroupMem, HasNoWait);
+                                              DynCGGroupMem, HasNoWait, false);
 
     // Assume no error was returned because TaskBodyCB and
     // EmitTargetCallFallbackCB don't produce any.

``````````

</details>


https://github.com/llvm/llvm-project/pull/152830
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to