https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/182856
>From d50d39e92b8d29e0747af579fb38a103b758c803 Mon Sep 17 00:00:00 2001 From: Sergio Afonso <[email protected]> Date: Thu, 5 Feb 2026 14:10:18 +0000 Subject: [PATCH 1/2] [MLIR][OpenMP] Unify device shared memory logic This patch creates a utils library for the OpenMP dialect with functions used by MLIR to LLVM IR translation as well as the stack-to-shared pass to determine which allocations must use local stack memory or device shared memory. --- .../include/mlir/Dialect/OpenMP/Utils/Utils.h | 53 +++++++++ mlir/lib/Dialect/OpenMP/CMakeLists.txt | 1 + .../Dialect/OpenMP/Transforms/CMakeLists.txt | 1 + .../OpenMP/Transforms/StackToShared.cpp | 98 +++-------------- mlir/lib/Dialect/OpenMP/Utils/CMakeLists.txt | 13 +++ mlir/lib/Dialect/OpenMP/Utils/Utils.cpp | 104 ++++++++++++++++++ .../LLVMIR/Dialect/OpenMP/CMakeLists.txt | 1 + .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 103 ++--------------- 8 files changed, 196 insertions(+), 178 deletions(-) create mode 100644 mlir/include/mlir/Dialect/OpenMP/Utils/Utils.h create mode 100644 mlir/lib/Dialect/OpenMP/Utils/CMakeLists.txt create mode 100644 mlir/lib/Dialect/OpenMP/Utils/Utils.cpp diff --git a/mlir/include/mlir/Dialect/OpenMP/Utils/Utils.h b/mlir/include/mlir/Dialect/OpenMP/Utils/Utils.h new file mode 100644 index 0000000000000..ce625c7170efe --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenMP/Utils/Utils.h @@ -0,0 +1,53 @@ +//===- Utils.h - OpenMP dialect utilities -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes for various OpenMP utilities. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENMP_UTILS_UTILS_H_ +#define MLIR_DIALECT_OPENMP_UTILS_UTILS_H_ + +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" + +namespace mlir { +namespace omp { + +/// Check whether the value representing an allocation, assumed to have been +/// defined in a shared device context, is used in a manner that would require +/// device shared memory for correctness. +/// +/// When a use takes place inside an omp.parallel region and it's not as a +/// private (not firstprivate) clause argument, or when it is a reduction +/// argument passed to omp.parallel or a function call argument, then the +/// defining allocation is eligible for replacement with shared memory. +/// +/// \see mlir::omp::opInSharedDeviceContext(). +bool allocaUsesRequireSharedMem(Value alloc); + +/// Check whether the given operation is located in a context where an +/// allocation to be used by multiple threads in a parallel region would have to +/// be placed in device shared memory to be accessible. +/// +/// That means that it is inside of a target device module, it is a generic +/// (non-SPMD) target region or nested inside of one, or it is located in a +/// device function, and it is not inside of a parallel region. +/// +/// This represents a necessary but not sufficient set of conditions to use +/// device shared memory in place of regular allocas. For some variables, the +/// associated OpenMP construct or their uses might also need to be taken into +/// account. +/// +/// \see mlir::omp::allocaUsesRequireSharedMem().
+bool opInSharedDeviceContext(Operation &op); + +} // namespace omp +} // namespace mlir + +#endif // MLIR_DIALECT_OPENMP_UTILS_UTILS_H_ diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt index 9f57627c321fb..31167e6af908b 100644 --- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt index 8f308226394c6..998e45ebe2539 100644 --- a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIROpenMPTransforms MLIRLLVMDialect MLIROpenACCMPCommon MLIROpenMPDialect + MLIROpenMPUtils MLIRPass MLIRSupport MLIRTransforms diff --git a/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp b/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp index 0edccf53a2031..6b54f9c14013c 100644 --- a/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp +++ b/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp @@ -15,7 +15,9 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/Utils/Utils.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { namespace omp { @@ -26,94 +28,20 @@ namespace omp { using namespace mlir; -/// When a use takes place inside an omp.parallel region and it's not as a -/// private clause argument, or when it is a reduction argument passed to -/// omp.parallel or a function call argument, then the defining allocation is -/// eligible for replacement with shared memory. 
-static bool allocaUseRequiresDeviceSharedMem(const OpOperand &use) { - Operation *owner = use.getOwner(); - if (auto parallelOp = dyn_cast<omp::ParallelOp>(owner)) { - if (llvm::is_contained(parallelOp.getReductionVars(), use.get())) - return true; - } else if (auto callOp = dyn_cast<CallOpInterface>(owner)) { - if (llvm::is_contained(callOp.getArgOperands(), use.get())) - return true; - } - - // If it is used directly inside of a parallel region, it has to be replaced - // unless the use is a private clause. - if (owner->getParentOfType<omp::ParallelOp>()) { - if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) { - if (auto privateSyms = - cast_or_null<ArrayAttr>(owner->getAttr("private_syms"))) { - for (auto [var, sym] : - llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) { - if (var != use.get()) - continue; - - auto moduleOp = owner->getParentOfType<ModuleOp>(); - auto privateOp = cast<omp::PrivateClauseOp>( - moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym))); - return privateOp.getDataSharingType() != - omp::DataSharingClauseType::Private; - } - } - } - return true; - } - return false; -} - -static bool shouldReplaceAllocaWithUses(const Operation::use_range &uses) { - // Check direct uses and also follow hlfir.declare/fir.convert uses. - for (const OpOperand &use : uses) { - Operation *owner = use.getOwner(); - if (llvm::isa<LLVM::AddrSpaceCastOp, LLVM::GEPOp>(owner)) { - if (shouldReplaceAllocaWithUses(owner->getUses())) - return true; - } else if (allocaUseRequiresDeviceSharedMem(use)) { - return true; - } - } - - return false; -} - -// TODO: Refactor the logic in `shouldReplaceAllocaWithDeviceSharedMem`, -// `shouldReplaceAllocaWithUses` and `allocaUseRequiresDeviceSharedMem` to -// be reusable by the MLIR to LLVM IR translation stage, as something very -// similar is also implemented there to choose between allocas and device -// shared memory allocations when processing OpenMP reductions, mapping and -// privatization. 
+/// Tell whether to replace an operation representing a stack allocation with a +/// device shared memory allocation/deallocation pair based on the location of +/// the allocation and its uses. static bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op) { - auto offloadIface = op.getParentOfType<omp::OffloadModuleInterface>(); - if (!offloadIface || !offloadIface.getIsTargetDevice()) - return false; - - auto targetOp = op.getParentOfType<omp::TargetOp>(); - - // It must be inside of a generic omp.target or in a target device function, - // and not inside of omp.parallel. - if (auto parallelOp = op.getParentOfType<omp::ParallelOp>()) { - if (!targetOp || targetOp->isProperAncestor(parallelOp)) - return false; - } - - if (targetOp) { - if (targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) != - omp::TargetExecMode::generic) - return false; - } else { - auto declTargetIface = op.getParentOfType<omp::DeclareTargetInterface>(); - if (!declTargetIface || !declTargetIface.isDeclareTarget() || - declTargetIface.getDeclareTargetDeviceType() == - omp::DeclareTargetDeviceType::host) - return false; - } - - return shouldReplaceAllocaWithUses(op.getUses()); + return omp::opInSharedDeviceContext(op) && + llvm::any_of(op.getResults(), [&](Value result) { + return omp::allocaUsesRequireSharedMem(result); + }); } +/// Based on the location of the definition of the given value representing the +/// result of a device shared memory allocation, find the corresponding points +/// where its deallocation should be placed and introduce `omp.free_shared_mem` +/// ops at those points. 
static void insertDeviceSharedMemDeallocation(OpBuilder &builder, TypeAttr elemType, Value arraySize, diff --git a/mlir/lib/Dialect/OpenMP/Utils/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Utils/CMakeLists.txt new file mode 100644 index 0000000000000..8fd8ba2622c68 --- /dev/null +++ b/mlir/lib/Dialect/OpenMP/Utils/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIROpenMPUtils + Utils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + MLIROpenACCMPCommon + MLIROpenMPDialect + MLIRSupport + ) diff --git a/mlir/lib/Dialect/OpenMP/Utils/Utils.cpp b/mlir/lib/Dialect/OpenMP/Utils/Utils.cpp new file mode 100644 index 0000000000000..f5b7aa7ca2e2c --- /dev/null +++ b/mlir/lib/Dialect/OpenMP/Utils/Utils.cpp @@ -0,0 +1,104 @@ +//===- Utils.cpp - OpenMP dialect utilities -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements various OpenMP dialect utilities.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenMP/Utils/Utils.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" + +using namespace mlir; + +static bool allocaUseRequiresSharedMem(const OpOperand &use) { + Operation *owner = use.getOwner(); + if (auto parallelOp = dyn_cast<omp::ParallelOp>(owner)) { + if (llvm::is_contained(parallelOp.getReductionVars(), use.get())) + return true; + } else if (auto callOp = dyn_cast<CallOpInterface>(owner)) { + if (llvm::is_contained(callOp.getArgOperands(), use.get())) + return true; + } + + // If it is used directly inside of a parallel region, it has to be replaced + // unless the use is a private clause. + if (owner->getParentOfType<omp::ParallelOp>()) { + if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) { + if (auto privateSyms = + cast_or_null<ArrayAttr>(owner->getAttr("private_syms"))) { + for (auto [var, sym] : + llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) { + if (var != use.get()) + continue; + + auto moduleOp = owner->getParentOfType<ModuleOp>(); + auto privateOp = cast<omp::PrivateClauseOp>( + moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym))); + return privateOp.getDataSharingType() != + omp::DataSharingClauseType::Private; + } + } + } + return true; + } + return false; +} + +bool mlir::omp::allocaUsesRequireSharedMem(Value alloc) { + for (const OpOperand &use : alloc.getUses()) { + Operation *owner = use.getOwner(); + if (isa<LLVM::AddrSpaceCastOp, LLVM::GEPOp>(owner)) { + if (llvm::any_of(owner->getResults(), [&](Value result) { + return allocaUsesRequireSharedMem(result); + })) + return true; + } else if (allocaUseRequiresSharedMem(use)) { + return true; + } + } + return false; +} + +bool mlir::omp::opInSharedDeviceContext(Operation &op) { + if (isa<omp::ParallelOp>(op)) + return false; + + auto offloadIface = op.getParentOfType<omp::OffloadModuleInterface>(); + if 
(!offloadIface || !offloadIface.getIsTargetDevice()) + return false; + + auto targetOp = op.getParentOfType<omp::TargetOp>(); + + // It must be inside of a generic omp.target or in a target device function, + // and not inside of omp.parallel. + if (auto parallelOp = op.getParentOfType<omp::ParallelOp>()) { + if (!targetOp || targetOp->isProperAncestor(parallelOp)) + return false; + } + + // The omp.target operation itself is considered in a shared device context in + // order to properly process its own allocation-defining entry block + // arguments. + if (!targetOp) + targetOp = dyn_cast<omp::TargetOp>(op); + + if (targetOp) { + if (targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) != + omp::TargetExecMode::generic) + return false; + } else { + auto declTargetIface = op.getParentOfType<omp::DeclareTargetInterface>(); + if (!declTargetIface || !declTargetIface.isDeclareTarget() || + declTargetIface.getDeclareTargetDeviceType() == + omp::DeclareTargetDeviceType::host) + return false; + } + return true; +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt index 0a5d7c6e22058..eb748d8b43630 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_translation_library(MLIROpenMPToLLVMIRTranslation MLIRIR MLIRLLVMDialect MLIROpenMPDialect + MLIROpenMPUtils MLIRSupport MLIRTargetLLVMIRExport MLIRTransformUtils diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index dbdffd5e565c3..4acec8dcf40f7 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include 
"mlir/Dialect/OpenMP/OpenMPInterfaces.h" +#include "mlir/Dialect/OpenMP/Utils/Utils.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h" @@ -1133,81 +1134,6 @@ struct DeferredStore { }; } // namespace -/// Check whether allocations for the given operation might potentially have to -/// be done in device shared memory. That means we're compiling for an -/// offloading target, the operation is neither an `omp::TargetOp` nor nested -/// inside of one, or it is and that target region represents a Generic -/// (non-SPMD) kernel. -/// -/// This represents a necessary but not sufficient set of conditions to use -/// device shared memory in place of regular allocas. For some variables, the -/// associated OpenMP construct or their uses might also need to be taken into -/// account. -static bool -mightAllocInDeviceSharedMemory(Operation &op, - const llvm::OpenMPIRBuilder &ompBuilder) { - if (!ompBuilder.Config.isTargetDevice()) - return false; - - auto targetOp = dyn_cast<omp::TargetOp>(op); - if (!targetOp) - targetOp = op.getParentOfType<omp::TargetOp>(); - - return !targetOp || - targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) == - omp::TargetExecMode::generic; -} - -/// Check whether the entry block argument representing the private copy of a -/// variable in an OpenMP construct must be allocated in device shared memory, -/// based on what the uses of that copy are. -/// -/// This must only be called if a previous call to -/// \c mightAllocInDeviceSharedMemory has already returned \c true for the -/// operation that owns the specified block argument. 
-static bool mustAllocPrivateVarInDeviceSharedMemory(BlockArgument value) { - Operation *parentOp = value.getOwner()->getParentOp(); - auto moduleOp = parentOp->getParentOfType<ModuleOp>(); - for (auto *user : value.getUsers()) { - if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) { - if (llvm::is_contained(parallelOp.getReductionVars(), value)) - return true; - } else if (auto callOp = dyn_cast<CallOpInterface>(user)) { - if (llvm::is_contained(callOp.getArgOperands(), value)) - return true; - } - - if (auto parallelOp = user->getParentOfType<omp::ParallelOp>()) { - if (parentOp->isProperAncestor(parallelOp)) { - // If it is used directly inside of a parallel region, skip private - // clause uses. - bool isPrivateClauseUse = false; - if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(user)) { - if (auto privateSyms = llvm::cast_or_null<ArrayAttr>( - user->getAttr("private_syms"))) { - for (auto [var, sym] : - llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) { - if (var != value) - continue; - - auto privateOp = cast<omp::PrivateClauseOp>( - moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym))); - if (privateOp.getCopyRegion().empty()) { - isPrivateClauseUse = true; - break; - } - } - } - } - if (!isPrivateClauseUse) - return true; - } - } - } - - return false; -} - /// Allocate space for privatized reduction variables. 
/// `deferredStores` contains information to create store operations which needs /// to be inserted after all allocas @@ -1226,8 +1152,7 @@ allocReductionVars(T op, ArrayRef<BlockArgument> reductionArgs, builder.SetInsertPoint(allocaIP.getBlock()->getTerminator()); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) && - mightAllocInDeviceSharedMemory(*op, *ompBuilder); + bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op); // delay creating stores until after all allocas deferredStores.reserve(op.getNumReductionVars()); @@ -1358,8 +1283,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs, return success(); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) && - mightAllocInDeviceSharedMemory(*op, *ompBuilder); + bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op); llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init"); auto allocaIP = llvm::IRBuilderBase::InsertPoint( @@ -1606,8 +1530,7 @@ static LogicalResult createReductionsAndCleanup( reductionRegions, privateReductionVariables, moduleTranslation, builder, "omp.reduction.cleanup"); - bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) && - mightAllocInDeviceSharedMemory(*op, *ompBuilder); + bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op); if (useDeviceSharedMem) { for (auto [var, reductionDecl] : llvm::zip_equal(privateReductionVariables, reductionDecls)) @@ -1798,9 +1721,7 @@ allocatePrivateVars(T op, llvm::IRBuilderBase &builder, llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - bool mightUseDeviceSharedMem = - isa<omp::TargetOp, omp::TeamsOp, omp::DistributeOp>(*op) && - mightAllocInDeviceSharedMemory(*op, *ompBuilder); + bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op); unsigned int allocaAS = 
moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace(); unsigned int defaultAS = moduleTranslation.getLLVMModule() @@ -1814,8 +1735,7 @@ allocatePrivateVars(T op, llvm::IRBuilderBase &builder, moduleTranslation.convertType(privDecl.getType()); builder.SetInsertPoint(allocaIP.getBlock()->getTerminator()); llvm::Value *llvmPrivateVar = nullptr; - if (mightUseDeviceSharedMem && - mustAllocPrivateVarInDeviceSharedMemory(blockArg)) { + if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) { llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType); } else { llvmPrivateVar = builder.CreateAlloca( @@ -1952,14 +1872,11 @@ cleanupPrivateVars(T op, llvm::IRBuilderBase &builder, "`omp.private` op in"); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - bool mightUseDeviceSharedMem = - isa<omp::TargetOp, omp::TeamsOp, omp::DistributeOp>(*op) && - mightAllocInDeviceSharedMemory(*op, *ompBuilder); + bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op); for (auto [privDecl, llvmPrivVar, blockArg] : llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars, privateVarsInfo.blockArgs)) { - if (mightUseDeviceSharedMem && - mustAllocPrivateVarInDeviceSharedMemory(blockArg)) { + if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) { ompBuilder->createOMPFreeShared( builder, llvmPrivVar, moduleTranslation.convertType(privDecl.getType())); @@ -6577,8 +6494,8 @@ static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor( // Create the allocation for the argument. 
llvm::Value *v = nullptr; - if (mightAllocInDeviceSharedMemory(*targetOp, ompBuilder) && - mustAllocPrivateVarInDeviceSharedMemory(mlirArg)) { + if (omp::opInSharedDeviceContext(*targetOp) && + omp::allocaUsesRequireSharedMem(mlirArg)) { // Use the beginning of the codeGenIP rather than the usual allocation point // for shared memory allocations because otherwise these would be done prior // to the target initialization call. Also, the exit block (where the >From 2dc537e6538321c60dff443e3f16a0e8d900274c Mon Sep 17 00:00:00 2001 From: Sergio Afonso <[email protected]> Date: Tue, 3 Mar 2026 11:23:58 +0000 Subject: [PATCH 2/2] simplify private clause check --- mlir/lib/Dialect/OpenMP/Utils/Utils.cpp | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/OpenMP/Utils/Utils.cpp b/mlir/lib/Dialect/OpenMP/Utils/Utils.cpp index f5b7aa7ca2e2c..e32f47b92cf80 100644 --- a/mlir/lib/Dialect/OpenMP/Utils/Utils.cpp +++ b/mlir/lib/Dialect/OpenMP/Utils/Utils.cpp @@ -31,19 +31,16 @@ static bool allocaUseRequiresSharedMem(const OpOperand &use) { // unless the use is a private clause. 
if (owner->getParentOfType<omp::ParallelOp>()) { if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) { - if (auto privateSyms = - cast_or_null<ArrayAttr>(owner->getAttr("private_syms"))) { - for (auto [var, sym] : - llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) { - if (var != use.get()) - continue; - - auto moduleOp = owner->getParentOfType<ModuleOp>(); - auto privateOp = cast<omp::PrivateClauseOp>( - moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym))); - return privateOp.getDataSharingType() != - omp::DataSharingClauseType::Private; - } + OperandRange privateVars = argIface.getPrivateVars(); + auto it = llvm::find(privateVars, use.get()); + if (it != privateVars.end()) { + auto privateSyms = owner->getAttrOfType<ArrayAttr>("private_syms"); + size_t idx = std::distance(privateVars.begin(), it); + auto privateOp = + SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>( + owner, cast<SymbolRefAttr>(privateSyms[idx])); + return privateOp.getDataSharingType() != + omp::DataSharingClauseType::Private; } } return true; _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
