================ @@ -0,0 +1,446 @@ +//===- LowerWorkshare.cpp - special cases for bufferization -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the lowering of omp.workshare to other omp constructs. +// +// This pass is tasked with parallelizing the loops nested in +// workshare.loop_wrapper while both the Fortran to mlir lowering and the hlfir +// to fir lowering pipelines are responsible for emitting the +// workshare.loop_wrapper ops where appropriate according to the +// `shouldUseWorkshareLowering` function. +// +//===----------------------------------------------------------------------===// + +#include <flang/Optimizer/Builder/FIRBuilder.h> +#include <flang/Optimizer/Dialect/FIROps.h> +#include <flang/Optimizer/Dialect/FIRType.h> +#include <flang/Optimizer/HLFIR/HLFIROps.h> +#include <flang/Optimizer/OpenMP/Passes.h> +#include <llvm/ADT/BreadthFirstIterator.h> +#include <llvm/ADT/STLExtras.h> +#include <llvm/ADT/SmallVectorExtras.h> +#include <llvm/ADT/iterator_range.h> +#include <llvm/Support/ErrorHandling.h> +#include <mlir/Dialect/Arith/IR/Arith.h> +#include <mlir/Dialect/LLVMIR/LLVMTypes.h> +#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h> +#include <mlir/Dialect/OpenMP/OpenMPDialect.h> +#include <mlir/Dialect/SCF/IR/SCF.h> +#include <mlir/IR/BuiltinOps.h> +#include <mlir/IR/IRMapping.h> +#include <mlir/IR/OpDefinition.h> +#include <mlir/IR/PatternMatch.h> +#include <mlir/IR/Visitors.h> +#include <mlir/Interfaces/SideEffectInterfaces.h> +#include <mlir/Support/LLVM.h> +#include <mlir/Transforms/GreedyPatternRewriteDriver.h> + +#include <variant> + +namespace flangomp { +#define GEN_PASS_DEF_LOWERWORKSHARE +#include "flang/Optimizer/OpenMP/Passes.h.inc" +} // 
namespace flangomp + +#define DEBUG_TYPE "lower-workshare" + +using namespace mlir; + +namespace flangomp { + +// Checks for nesting pattern below as we need to avoid sharing the work of +// statements which are nested in some constructs such as omp.critical or +// another omp.parallel. +// +// omp.workshare { // `wsOp` +// ... +// omp.T { // `parent` +// ... +// `op` +// +template <typename T> +static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) { + T parent = op->getParentOfType<T>(); + if (!parent) + return false; + return wsOp->isProperAncestor(parent); +} + +bool shouldUseWorkshareLowering(Operation *op) { + auto parentWorkshare = op->getParentOfType<omp::WorkshareOp>(); + + if (!parentWorkshare) + return false; + + if (isNestedIn<omp::CriticalOp>(parentWorkshare, op)) + return false; + + // 2.8.3 workshare Construct + // For a parallel construct, the construct is a unit of work with respect to + // the workshare construct. The statements contained in the parallel construct + // are executed by a new thread team. + if (isNestedIn<omp::ParallelOp>(parentWorkshare, op)) + return false; + + // 2.8.2 single Construct + // Binding The binding thread set for a single region is the current team. A + // single region binds to the innermost enclosing parallel region. + // Description Only one of the encountering threads will execute the + // structured block associated with the single construct. + if (isNestedIn<omp::SingleOp>(parentWorkshare, op)) + return false; + + return true; +} + +} // namespace flangomp + +namespace { + +struct SingleRegion { + Block::iterator begin, end; +}; + +static bool mustParallelizeOp(Operation *op) { + return op + ->walk([&](Operation *nested) { + // We need to be careful not to pick up workshare.loop_wrapper in nested + // omp.parallel{omp.workshare} regions, i.e. make sure that `nested` + // binds to the workshare region we are currently handling. 
+ // + // For example: + // + // omp.parallel { + // omp.workshare { // currently handling this + // omp.parallel { + // omp.workshare { // nested workshare + // omp.workshare.loop_wrapper {} + // + // Therefore, we skip if we encounter a nested omp.workshare. + if (isa<omp::WorkshareOp>(op)) + return WalkResult::skip(); + if (isa<omp::WorkshareLoopWrapperOp>(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted(); +} + +static bool isSafeToParallelize(Operation *op) { + return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) || + isMemoryEffectFree(op); +} + +/// Simple shallow copies suffice for our purposes in this pass, so we implement +/// this simpler alternative to the full fledged `createCopyFunc` in the +/// frontend +static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType, + fir::FirOpBuilder builder) { + mlir::ModuleOp module = builder.getModule(); + auto rt = cast<fir::ReferenceType>(varType); + mlir::Type eleTy = rt.getEleTy(); + std::string copyFuncName = + fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy"); + + if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName)) + return decl; + // create function + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::OpBuilder modBuilder(module.getBodyRegion()); + llvm::SmallVector<mlir::Type> argsTy = {varType, varType}; + auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {}); + mlir::func::FuncOp funcOp = + modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType); + funcOp.setVisibility(mlir::SymbolTable::Visibility::Private); + builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy, + {loc, loc}); + builder.setInsertionPointToStart(&funcOp.getRegion().back()); + + Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(0)); + builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(1)); + + builder.create<mlir::func::ReturnOp>(loc); + return 
funcOp; +} + +static bool isUserOutsideSR(Operation *user, Operation *parentOp, + SingleRegion sr) { + while (user->getParentOp() != parentOp) + user = user->getParentOp(); + return sr.begin->getBlock() != user->getBlock() || + !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user)); +} + +static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) { + Block *srBlock = sr.begin->getBlock(); + Operation *parentOp = srBlock->getParentOp(); + + for (auto &use : v.getUses()) { + Operation *user = use.getOwner(); + if (isUserOutsideSR(user, parentOp, sr)) + return true; + + // Results of nested users cannot be used outside of the SR + if (user->getBlock() != srBlock) + continue; + + // A non-safe to parallelize operation will be handled separately + if (!isSafeToParallelize(user)) + continue; + + for (auto res : user->getResults()) + if (isTransitivelyUsedOutside(res, sr)) + return true; + } + return false; +} + +/// We clone pure operations in both the parallel and single blocks. 
this +/// functions cleans them up if they end up with no uses +static void cleanupBlock(Block *block) { + for (Operation &op : llvm::make_early_inc_range( + llvm::make_range(block->rbegin(), block->rend()))) + if (isOpTriviallyDead(&op)) + op.erase(); +} + +static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, + IRMapping &rootMapping, Location loc, + mlir::DominanceInfo &di) { + OpBuilder rootBuilder(sourceRegion.getContext()); + ModuleOp m = sourceRegion.getParentOfType<ModuleOp>(); + OpBuilder copyFuncBuilder(m.getBodyRegion()); + fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m); + + auto mapReloadedValue = + [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder, + OpBuilder parallelBuilder, IRMapping singleMapping) -> Value { + if (auto reloaded = rootMapping.lookupOrNull(v)) + return nullptr; + Type ty = v.getType(); + Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty); + singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc); + Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc); + rootMapping.map(v, reloaded); + return alloc; + }; + + auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder, + OpBuilder singleBuilder, + OpBuilder parallelBuilder) -> SmallVector<Value> { + IRMapping singleMapping = rootMapping; + SmallVector<Value> copyPrivate; + + for (Operation &op : llvm::make_range(sr.begin, sr.end)) { + if (isSafeToParallelize(&op)) { + singleBuilder.clone(op, singleMapping); + parallelBuilder.clone(op, rootMapping); + } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) { + auto hoisted = + cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping)); + rootMapping.map(&*alloca, &*hoisted); + rootMapping.map(alloca.getResult(), hoisted.getResult()); + copyPrivate.push_back(hoisted); + } else { + singleBuilder.clone(op, singleMapping); + // Prepare reloaded values for results of operations that cannot be + // safely parallelized and which are used after the 
region `sr` + for (auto res : op.getResults()) { + if (isTransitivelyUsedOutside(res, sr)) { + auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder, + parallelBuilder, singleMapping); + if (alloc) + copyPrivate.push_back(alloc); + } + } + } + } + singleBuilder.create<omp::TerminatorOp>(loc); + return copyPrivate; + }; + + for (Block &block : sourceRegion) { + Block *targetBlock = rootBuilder.createBlock( + &targetRegion, {}, block.getArgumentTypes(), + llvm::map_to_vector(block.getArguments(), + [](BlockArgument arg) { return arg.getLoc(); })); + rootMapping.map(&block, targetBlock); + rootMapping.map(block.getArguments(), targetBlock->getArguments()); + } + + auto handleOneBlock = [&](Block &block) { + Block &targetBlock = *rootMapping.lookup(&block); + rootBuilder.setInsertionPointToStart(&targetBlock); + Operation *terminator = block.getTerminator(); + SmallVector<std::variant<SingleRegion, Operation *>> regions; + + auto it = block.begin(); + auto getOneRegion = [&]() { + if (&*it == terminator) + return false; + if (mustParallelizeOp(&*it)) { + regions.push_back(&*it); + it++; + return true; + } + SingleRegion sr; + sr.begin = it; + while (&*it != terminator && !mustParallelizeOp(&*it)) + it++; + sr.end = it; + assert(sr.begin != sr.end); + regions.push_back(sr); + return true; + }; + while (getOneRegion()) + ; + + for (auto [i, opOrSingle] : llvm::enumerate(regions)) { + bool isLast = i + 1 == regions.size(); + if (std::holds_alternative<SingleRegion>(opOrSingle)) { + OpBuilder singleBuilder(sourceRegion.getContext()); + Block *singleBlock = new Block(); ---------------- tblah wrote:
I'm not sure about this. Everywhere else in flang we use `OpBuilder::createBlock`. As far as I can tell, the only difference is that `createBlock` notifies builder listeners, and we don't have any here. Still, this could be surprising if the code is changed later (e.g. if a listener is added). From the implementation of `createBlock` I can see that this isn't incorrect, so it is okay with me if you have a good reason for it. https://github.com/llvm/llvm-project/pull/101446 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits