================ @@ -461,18 +461,51 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { // Replace the loop. { OpBuilder::InsertionGuard allocaGuard(rewriter); - auto loop = rewriter.create<omp::WsloopOp>( + // Create worksharing loop wrapper. + auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc()); + if (!reductionVariables.empty()) { + wsloopOp.setReductionsAttr( + ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); + wsloopOp.getReductionVarsMutable().append(reductionVariables); + } + rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator. + + // The wrapper's entry block arguments will define the reduction + // variables. + llvm::SmallVector<mlir::Type> reductionTypes; + reductionTypes.reserve(reductionVariables.size()); + llvm::transform(reductionVariables, std::back_inserter(reductionTypes), + [](mlir::Value v) { return v.getType(); }); + rewriter.createBlock( + &wsloopOp.getRegion(), {}, reductionTypes, + llvm::SmallVector<mlir::Location>(reductionVariables.size(), + parallelOp.getLoc())); + + rewriter.setInsertionPoint( + rewriter.create<omp::TerminatorOp>(parallelOp.getLoc())); + + // Create loop nest and populate region with contents of scf.parallel. + auto loopOp = rewriter.create<omp::LoopNestOp>( parallelOp.getLoc(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep()); - rewriter.create<omp::TerminatorOp>(loc); - rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(), - loop.getRegion().begin()); + rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(), + loopOp.getRegion().begin()); + + // Remove reduction-related block arguments from omp.loop_nest and + // redirect uses to the corresponding omp.wsloop block argument. 
+ mlir::Block &loopOpEntryBlock = loopOp.getRegion().front(); + unsigned numLoops = parallelOp.getNumLoops(); + rewriter.replaceAllUsesWith( + loopOpEntryBlock.getArguments().drop_front(numLoops), + wsloopOp.getRegion().getArguments()); + loopOpEntryBlock.eraseArguments( + numLoops, loopOpEntryBlock.getNumArguments() - numLoops); - Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(), - loop.getRegion().begin()->begin()); + Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(), + loopOp.getRegion().begin()->begin()); ---------------- Meinersbur wrote:
With `loopOpEntryBlock`, this can be simplified:

```suggestion
Block *ops = rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
rewriter.setInsertionPointToStart(&loopOpEntryBlock);
```

https://github.com/llvm/llvm-project/pull/89212
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits