================
@@ -391,29 +414,236 @@ class MapInfoFinalizationPass
/// of the base address index.
void adjustMemberIndices(
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &memberIndices,
- size_t memberIndex) {
- llvm::SmallVector<int64_t> baseAddrIndex = memberIndices[memberIndex];
+ ParentAndPlacement parentAndPlacement) {
+ llvm::SmallVector<int64_t> baseAddrIndex =
+ memberIndices[parentAndPlacement.index];
+ auto &expansionIndices = expandedBaseAddr[parentAndPlacement.parent];
// If we find another member that is "derived/a member of" the descriptor
// that is not the descriptor itself, we must insert a 0 for the new base
// address we have just added for the descriptor into the list at the
// appropriate position to maintain correctness of the positional/index
data
// for that member.
- for (llvm::SmallVector<int64_t> &member : memberIndices)
+ for (auto [i, member] : llvm::enumerate(memberIndices)) {
+ if (std::find(expansionIndices.begin(), expansionIndices.end(), i) !=
+ expansionIndices.end())
+ if (member.size() == baseAddrIndex.size() + 1 &&
+ member[baseAddrIndex.size()] == 0)
+ continue;
+
if (member.size() > baseAddrIndex.size() &&
std::equal(baseAddrIndex.begin(), baseAddrIndex.end(),
member.begin()))
member.insert(std::next(member.begin(), baseAddrIndex.size()), 0);
+ }
// Add the base address index to the main base address member data
baseAddrIndex.push_back(0);
- // Insert our newly created baseAddrIndex into the larger list of indices
at
- // the correct location.
- memberIndices.insert(std::next(memberIndices.begin(), memberIndex + 1),
+ uint64_t newIdxInsert = parentAndPlacement.index + 1;
+ expansionIndices.push_back(newIdxInsert);
+
+ // Insert our newly created baseAddrIndex into the larger list of
+ // indices at the correct location.
+ memberIndices.insert(std::next(memberIndices.begin(), newIdxInsert),
baseAddrIndex);
}
+ // This function takes a Map clause owning target operation (e.g. TargetOp or
+ // TargetDataOp) and a lambda function, the lambda function is invoked on the
+ // various map clause ranges of the target operation that was passed in (e.g.
+ // the use_device_ptr/addr and regular maps count as map clause ranges for
the
+ // purpose of this function) with the intent of inserting new maps into the
+ // range in a manner that is consistent with the target that was passed in.
+ //
+ // The lambda function should take 3 parameters a range that represents the
+ // map range, an operation representing the target and an unsigned integer
+ // representing the start index for the map range in terms of the targets
+ // block argument list. The insertion behaviour of the function is left to
the
+ // lambda.
+ void
+ insertIntoMapClauseInterface(mlir::Operation *target,
+ std::function<void(mlir::MutableOperandRange &,
+ mlir::Operation *, unsigned)>
+ addOperands) {
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(target);
+
+ if (auto mapClauseOwner =
+ llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
+ mlir::MutableOperandRange mapVarsArr =
mapClauseOwner.getMapVarsMutable();
+ unsigned blockArgInsertIndex =
+ argIface
+ ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
+ : 0;
+ addOperands(mapVarsArr,
+ llvm::dyn_cast_if_present<mlir::omp::TargetOp>(target),
+ blockArgInsertIndex);
+ }
+
+ if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
+ mlir::MutableOperandRange useDevAddrMutableOpRange =
+ targetDataOp.getUseDeviceAddrVarsMutable();
+ addOperands(useDevAddrMutableOpRange, target,
+ argIface.getUseDeviceAddrBlockArgsStart() +
+ argIface.numUseDeviceAddrBlockArgs());
+
+ mlir::MutableOperandRange useDevPtrMutableOpRange =
+ targetDataOp.getUseDevicePtrVarsMutable();
+ addOperands(useDevPtrMutableOpRange, target,
+ argIface.getUseDevicePtrBlockArgsStart() +
+ argIface.numUseDevicePtrBlockArgs());
+ } else if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target)) {
+ mlir::MutableOperandRange hasDevAddrMutableOpRange =
+ targetOp.getHasDeviceAddrVarsMutable();
+ addOperands(hasDevAddrMutableOpRange, target,
+ argIface.getHasDeviceAddrBlockArgsStart() +
+ argIface.numHasDeviceAddrBlockArgs());
+ }
+ }
+
+ // This function aims to insert new maps derived from existing maps into the
+ // corresponding clause list, interlinking it correctly with block arguments
+ // where required.
+ void addDerivedMemberToTarget(
+ mlir::omp::MapInfoOp owner, mlir::omp::MapInfoOp derived,
+ llvm::SmallVectorImpl<ParentAndPlacement> &mapMemberUsers,
+ fir::FirOpBuilder &builder, mlir::Operation *target) {
+ auto addOperands = [&](mlir::MutableOperandRange &mapVarsArr,
+ mlir::Operation *directiveOp,
+ unsigned blockArgInsertIndex = 0) {
+ // Check we're inserting into the correct MapInfoOp list.
+ if (!llvm::is_contained(mapVarsArr.getAsOperandRange(),
+ mapMemberUsers.empty()
+ ? owner.getResult()
+ : mapMemberUsers[0].parent.getResult()))
+ return;
+
+ // Check we're not inserting a duplicate map.
+ if (llvm::is_contained(mapVarsArr.getAsOperandRange(),
+ derived.getResult()))
+ return;
+
+ llvm::SmallVector<mlir::Value> newMapOps;
+ newMapOps.reserve(mapVarsArr.size());
+ llvm::copy(mapVarsArr.getAsOperandRange(),
std::back_inserter(newMapOps));
+
+ newMapOps.push_back(derived);
+ if (directiveOp) {
+ directiveOp->getRegion(0).insertArgument(
+ blockArgInsertIndex, derived.getType(), derived.getLoc());
+ blockArgInsertIndex++;
+ }
+
+ mapVarsArr.assign(newMapOps);
----------------
agozillon wrote:
Thank you. I think I tried this in the original iteration of this patch a year
or so ago but didn't quite manage to get it working, and this variant has just
stuck around!
https://github.com/llvm/llvm-project/pull/177715
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits