llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

Argument structures are created when the LLVM IR corresponding to an OpenMP
construct is outlined into its own function. Currently, these structures are
allocated on the stack.

This patch changes that behavior when compiling for a target device and
outlining `parallel`-related IR, so that argument structures are allocated in
device shared memory instead of private stack space. This is necessary for the
threads executing the outlined function to have access to these arguments.

---
Full diff: https://github.com/llvm/llvm-project/pull/150925.diff


5 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+6) 
- (modified) llvm/include/llvm/Transforms/Utils/CodeExtractor.h (+34-5) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+89-9) 
- (modified) llvm/lib/Transforms/Utils/CodeExtractor.cpp (+56-17) 
- (modified) mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir (+5-5) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 110b0fde863c5..967fe38c0d635 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2159,7 +2159,13 @@ class OpenMPIRBuilder {
   /// during finalization.
   struct OutlineInfo {
     using PostOutlineCBTy = std::function<void(Function &)>;
+    using CustomArgAllocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
+    using CustomArgDeallocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     PostOutlineCBTy PostOutlineCB;
+    CustomArgAllocatorCBTy CustomArgAllocatorCB;
+    CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
     BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
     SmallVector<Value *, 2> ExcludeArgsFromAggregate;
 
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index 407eb50d2c7a3..cc472a5bf3576 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/Support/Compiler.h"
 #include <limits>
 
@@ -24,7 +25,6 @@ namespace llvm {
 
 template <typename PtrType> class SmallPtrSetImpl;
 class AllocaInst;
-class BasicBlock;
 class BlockFrequency;
 class BlockFrequencyInfo;
 class BranchProbabilityInfo;
@@ -85,6 +85,10 @@ class CodeExtractorAnalysisCache {
   /// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas
   ///    as arguments, and inserting stores to the arguments for any scalars.
   class CodeExtractor {
+    using CustomArgAllocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
+    using CustomArgDeallocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     using ValueSet = SetVector<Value *>;
 
     // Various bits of state computed on construction.
@@ -133,6 +137,25 @@ class CodeExtractorAnalysisCache {
     // space.
     bool ArgsInZeroAddressSpace;
 
+    // If set, this callback will be used to allocate the arguments in the
+    // caller before passing it to the outlined function holding the extracted
+    // piece of code.
+    CustomArgAllocatorCBTy *CustomArgAllocatorCB;
+
+    // A block outside of the extraction set where previously introduced
+    // intermediate allocations can be deallocated. This is only used when an
+    // custom deallocator is specified.
+    BasicBlock *DeallocationBlock;
+
+    // If set, this callback will be used to deallocate the arguments in the
+    // caller after running the outlined function holding the extracted piece of
+    // code. It will not be called if a custom allocator isn't also present.
+    //
+    // By default, this will be done at the end of the basic block containing
+    // the call to the outlined function, except if a deallocation block is
+    // specified. In that case, that will take precedence.
+    CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
+
   public:
     /// Create a code extractor for a sequence of blocks.
     ///
@@ -149,7 +172,9 @@ class CodeExtractorAnalysisCache {
     /// the function from which the code is being extracted.
     /// If ArgsInZeroAddressSpace param is set to true, then the aggregate
     /// param pointer of the outlined function is declared in zero address
-    /// space.
+    /// space. If a CustomArgAllocatorCB callback is specified, it will be used
+    /// to allocate any structures or variable copies needed to pass arguments
+    /// to the outlined function, rather than using regular allocas.
     LLVM_ABI
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
@@ -157,7 +182,10 @@ class CodeExtractorAnalysisCache {
                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                   bool AllowAlloca = false,
                   BasicBlock *AllocationBlock = nullptr,
-                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
+                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
+                  CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
+                  BasicBlock *DeallocationBlock = nullptr,
+                  CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
 
     /// Perform the extraction, returning the new function.
     ///
@@ -177,8 +205,9 @@ class CodeExtractorAnalysisCache {
     /// newly outlined function.
     /// \returns zero when called on a CodeExtractor instance where isEligible
     /// returns false.
-    LLVM_ABI Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
-                                         ValueSet &Inputs, ValueSet &Outputs);
+    LLVM_ABI Function *
+    extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, ValueSet &Inputs,
+                      ValueSet &Outputs);
 
     /// Verify that assumption cache isn't stale after a region is extracted.
     /// Returns true when verifier finds errors. AssumptionCache is passed as
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2e8fb5efb7743..a913958c0de9a 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -268,6 +268,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
   return Result;
 }
 
+/// Given a function, if it represents the entry point of a target kernel, this
+/// returns the execution mode flags associated to that kernel.
+static std::optional<omp::OMPTgtExecModeFlags>
+getTargetKernelExecMode(Function &Kernel) {
+  CallInst *TargetInitCall = nullptr;
+  for (Instruction &Inst : Kernel.getEntryBlock()) {
+    if (auto *Call = dyn_cast<CallInst>(&Inst)) {
+      if (Call->getCalledFunction()->getName() == "__kmpc_target_init") {
+        TargetInitCall = Call;
+        break;
+      }
+    }
+  }
+
+  if (!TargetInitCall)
+    return std::nullopt;
+
+  // Get the kernel mode information from the global variable associated to the
+  // first argument to the call to __kmpc_target_init. Refer to
+  // createTargetInit() to see how this is initialized.
+  Value *InitOperand = TargetInitCall->getArgOperand(0);
+  GlobalVariable *KernelEnv = nullptr;
+  if (auto *Cast = dyn_cast<ConstantExpr>(InitOperand))
+    KernelEnv = cast<GlobalVariable>(Cast->getOperand(0));
+  else
+    KernelEnv = cast<GlobalVariable>(InitOperand);
+  auto *KernelEnvInit = cast<ConstantStruct>(KernelEnv->getInitializer());
+  auto *ConfigEnv = cast<ConstantStruct>(KernelEnvInit->getOperand(0));
+  auto *KernelMode = cast<ConstantInt>(ConfigEnv->getOperand(2));
+  return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
+}
+
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
@@ -702,15 +734,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
-                            /* AggregateArgs */ true,
-                            /* BlockFrequencyInfo */ nullptr,
-                            /* BranchProbabilityInfo */ nullptr,
-                            /* AssumptionCache */ nullptr,
-                            /* AllowVarArgs */ true,
-                            /* AllowAlloca */ true,
-                            /* AllocaBlock*/ OI.OuterAllocaBB,
-                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
+    CodeExtractor Extractor(
+        Blocks, /* DominatorTree */ nullptr,
+        /* AggregateArgs */ true,
+        /* BlockFrequencyInfo */ nullptr,
+        /* BranchProbabilityInfo */ nullptr,
+        /* AssumptionCache */ nullptr,
+        /* AllowVarArgs */ true,
+        /* AllowAlloca */ true,
+        /* AllocaBlock*/ OI.OuterAllocaBB,
+        /* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
+        OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
+        /* DeallocationBlock */ OI.ExitBB,
+        OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
 
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
@@ -1614,6 +1650,50 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                              ThreadID, ToBeDeletedVec);
     };
+
+    std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+        getTargetKernelExecMode(*OuterFn);
+
+    // If OuterFn is not a Generic kernel, skip custom allocation. This causes
+    // the CodeExtractor to follow its default behavior. Otherwise, we need to
+    // use device shared memory to allocate argument structures.
+    if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
+      OI.CustomArgAllocatorCB = [this,
+                                 EntryBB](BasicBlock *, BasicBlock::iterator,
+                                          Type *ArgTy, const Twine &Name) {
+        // Instead of using the insertion point provided by the CodeExtractor,
+        // here we need to use the block that eventually calls the outlined
+        // function for the `parallel` construct.
+        //
+        // The reason is that the explicit deallocation call will be inserted
+        // within the outlined function, whereas the alloca insertion point
+        // might actually be located somewhere else in the caller. This becomes
+        // a problem when e.g. `parallel` is inside of a `distribute` construct,
+        // because the deallocation would be executed multiple times and the
+        // allocation just once (outside of the loop).
+        //
+        // TODO: Ideally, we'd want to do the allocation and deallocation
+        // outside of the `parallel` outlined function, hence using here the
+        // insertion point provided by the CodeExtractor. We can't do this at
+        // the moment because there is currently no way of passing an eligible
+        // insertion point for the explicit deallocation to the CodeExtractor,
+        // as that block is created (at least when nested inside of
+        // `distribute`) sometime after createParallel() completed, so it can't
+        // be stored in the OutlineInfo structure here.
+        //
+        // The current approach results in an explicit allocation and
+        // deallocation pair for each `distribute` loop iteration in that case,
+        // which is suboptimal.
+        return createOMPAllocShared(
+            InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
+            Name);
+      };
+      OI.CustomArgDeallocatorCB =
+          [this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
+                 Type *ArgTy) -> Instruction * {
+        return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
+      };
+    }
   } else {
     // Generate OpenMP host runtime call
     OI.PostOutlineCB = [=, ToBeDeletedVec =
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 7a9dd37b72205..a4943150fdffc 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -25,7 +25,6 @@
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
-#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -265,12 +264,18 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
                              BasicBlock *AllocationBlock, std::string Suffix,
-                             bool ArgsInZeroAddressSpace)
+                             bool ArgsInZeroAddressSpace,
+                             CustomArgAllocatorCBTy *CustomArgAllocatorCB,
+                             BasicBlock *DeallocationBlock,
+                             CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
       AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
+      CustomArgAllocatorCB(CustomArgAllocatorCB),
+      DeallocationBlock(DeallocationBlock),
+      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -1852,24 +1857,38 @@ CallInst *CodeExtractor::emitReplacerCall(
     if (StructValues.contains(output))
       continue;
 
-    AllocaInst *alloca = new AllocaInst(
-        output->getType(), DL.getAllocaAddrSpace(), nullptr,
-        output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
-    params.push_back(alloca);
-    ReloadOutputs.push_back(alloca);
+    Value *OutAlloc;
+    if (CustomArgAllocatorCB)
+      OutAlloc = (*CustomArgAllocatorCB)(
+          AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
+          output->getName() + ".loc");
+    else
+      OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
+                                nullptr, output->getName() + ".loc",
+                                AllocaBlock->getFirstInsertionPt());
+
+    params.push_back(OutAlloc);
+    ReloadOutputs.push_back(OutAlloc);
   }
 
-  AllocaInst *Struct = nullptr;
+  Instruction *Struct = nullptr;
   if (!StructValues.empty()) {
-    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
-                            "structArg", AllocaBlock->getFirstInsertionPt());
-    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-      auto *StructSpaceCast = new AddrSpaceCastInst(
-          Struct, PointerType ::get(Context, 0), "structArg.ascast");
-      StructSpaceCast->insertAfter(Struct->getIterator());
-      params.push_back(StructSpaceCast);
-    } else {
+    BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
+    if (CustomArgAllocatorCB) {
+      Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
+                                       "structArg");
       params.push_back(Struct);
+    } else {
+      Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
+                              "structArg", StructArgIP);
+      if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+        auto *StructSpaceCast = new AddrSpaceCastInst(
+            Struct, PointerType ::get(Context, 0), "structArg.ascast");
+        StructSpaceCast->insertAfter(Struct->getIterator());
+        params.push_back(StructSpaceCast);
+      } else {
+        params.push_back(Struct);
+      }
     }
 
     unsigned AggIdx = 0;
@@ -2013,6 +2032,26 @@ CallInst *CodeExtractor::emitReplacerCall(
   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
                                        {}, call);
 
+  // Deallocate variables that used a custom allocator.
+  if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
+    BasicBlock *DeallocBlock = codeReplacer;
+    BasicBlock::iterator DeallocIP = codeReplacer->end();
+    if (DeallocationBlock) {
+      DeallocBlock = DeallocationBlock;
+      DeallocIP = DeallocationBlock->getFirstInsertionPt();
+    }
+
+    int Index = 0;
+    for (Value *Output : outputs) {
+      if (!StructValues.contains(Output))
+        (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
+                                  ReloadOutputs[Index++], Output->getType());
+    }
+
+    if (Struct)
+      (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
+  }
+
   return call;
 }
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 60c6fa4dd8f1e..504e39c96f008 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -56,8 +56,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK-SAME: ptr %[[TMP:.*]], ptr %[[TMP0:.*]]) #{{[0-9]+}} {
 // CHECK:         %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
 // CHECK:         %[[TMP2:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr
-// CHECK:         %[[STRUCTARG:.*]] = alloca { ptr }, align 8, addrspace(5)
-// CHECK:         %[[STRUCTARG_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[STRUCTARG]] to ptr
 // CHECK:         %[[TMP3:.*]] = alloca ptr, align 8, addrspace(5)
 // CHECK:         %[[TMP4:.*]] = addrspacecast ptr addrspace(5) %[[TMP3]] to ptr
 // CHECK:         store ptr %[[TMP0]], ptr %[[TMP4]], align 8
@@ -65,12 +63,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP5]], -1
 // CHECK:         br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]]
 // CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP4]], align 8
+// CHECK:         %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
 // CHECK:         %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr))
-// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0
-// CHECK:         store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8
+// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
+// CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
 // CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
-// CHECK:         store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8
+// CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
 // CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
+// CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
 // CHECK: define internal void @[[FUNC1]](

``````````

</details>


https://github.com/llvm/llvm-project/pull/150925
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to