https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/89212
>From fdee8cf17cd7d2dbdb6cf872574776f02e70be7c Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Thu, 18 Apr 2024 10:55:16 +0100 Subject: [PATCH 1/2] [MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) This patch makes changes to the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass in order to introduce a nested `omp.loop_nest` as well, and to follow the new loop wrapper role for `omp.wsloop`. This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and pass tests. --- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 52 ++++++++++++++----- .../Conversion/SCFToOpenMP/reductions.mlir | 5 ++ .../Conversion/SCFToOpenMP/scf-to-openmp.mlir | 31 ++++++++--- 3 files changed, 68 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 7f91367ad427a2..f0b8d6c5309357 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -461,18 +461,51 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { // Replace the loop. { OpBuilder::InsertionGuard allocaGuard(rewriter); - auto loop = rewriter.create<omp::WsloopOp>( + // Create worksharing loop wrapper. + auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc()); + if (!reductionVariables.empty()) { + wsloopOp.setReductionsAttr( + ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); + wsloopOp.getReductionVarsMutable().append(reductionVariables); + } + rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator. + + // The wrapper's entry block arguments will define the reduction + // variables. 
+ llvm::SmallVector<mlir::Type> reductionTypes; + reductionTypes.reserve(reductionVariables.size()); + llvm::transform(reductionVariables, std::back_inserter(reductionTypes), + [](mlir::Value v) { return v.getType(); }); + rewriter.createBlock( + &wsloopOp.getRegion(), {}, reductionTypes, + llvm::SmallVector<mlir::Location>(reductionVariables.size(), + parallelOp.getLoc())); + + rewriter.setInsertionPoint( + rewriter.create<omp::TerminatorOp>(parallelOp.getLoc())); + + // Create loop nest and populate region with contents of scf.parallel. + auto loopOp = rewriter.create<omp::LoopNestOp>( parallelOp.getLoc(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep()); - rewriter.create<omp::TerminatorOp>(loc); - rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(), - loop.getRegion().begin()); + rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(), + loopOp.getRegion().begin()); + + // Remove reduction-related block arguments from omp.loop_nest and + // redirect uses to the corresponding omp.wsloop block argument. 
+ mlir::Block &loopOpEntryBlock = loopOp.getRegion().front(); + unsigned numLoops = parallelOp.getNumLoops(); + rewriter.replaceAllUsesWith( + loopOpEntryBlock.getArguments().drop_front(numLoops), + wsloopOp.getRegion().getArguments()); + loopOpEntryBlock.eraseArguments( + numLoops, loopOpEntryBlock.getNumArguments() - numLoops); - Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(), - loop.getRegion().begin()->begin()); + Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(), + loopOp.getRegion().begin()->begin()); - rewriter.setInsertionPointToStart(&*loop.getRegion().begin()); + rewriter.setInsertionPointToStart(&*loopOp.getRegion().begin()); auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(), TypeRange()); @@ -481,11 +514,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { rewriter.mergeBlocks(ops, scopeBlock); rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange()); - if (!reductionVariables.empty()) { - loop.setReductionsAttr( - ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); - loop.getReductionVarsMutable().append(reductionVariables); - } } } diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir index 3b6c145d62f1a8..fc6d56559c2618 100644 --- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -28,6 +28,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index, // CHECK: omp.parallel // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF]] %[[BUF]] -> %[[PVT_BUF:[a-z0-9]+]] + // CHECK: omp.loop_nest // CHECK: memref.alloca_scope scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init (%zero) -> (f32) { @@ -43,6 +44,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index, } // CHECK: omp.yield } + // CHECK: omp.terminator // CHECK: 
omp.terminator // CHECK: llvm.load %[[BUF]] return @@ -107,6 +109,7 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index, %one = arith.constant 1 : i32 // CHECK: %[[RED_VAR:.*]] = llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr // CHECK: omp.wsloop reduction(@[[$REDI]] %[[RED_VAR]] -> %[[RED_PVT_VAR:.*]] : !llvm.ptr) + // CHECK: omp.loop_nest scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init (%one) -> (i32) { // CHECK: %[[C2:.*]] = arith.constant 2 : i32 @@ -208,6 +211,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index, // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF1]] %[[BUF1]] -> %[[PVT_BUF1:[a-z0-9]+]] // CHECK-SAME: @[[$REDF2]] %[[BUF2]] -> %[[PVT_BUF2:[a-z0-9]+]] + // CHECK: omp.loop_nest // CHECK: memref.alloca_scope %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init (%zero, %ione) -> (f32, i64) { @@ -236,6 +240,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index, } // CHECK: omp.yield } + // CHECK: omp.terminator // CHECK: omp.terminator // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32 // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64 diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir index acd2690c56e2e6..b2f19d294cb5fe 100644 --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -2,10 +2,11 @@ // CHECK-LABEL: @parallel func.func @parallel(%arg0: index, %arg1: index, %arg2: index, - %arg3: index, %arg4: index, %arg5: index) { + %arg3: index, %arg4: index, %arg5: index) { // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) { - // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { + // CHECK: omp.wsloop { + // CHECK: omp.loop_nest 
(%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { // CHECK: memref.alloca_scope scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> () @@ -13,6 +14,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index, // CHECK: omp.yield // CHECK: } } + // CHECK: omp.terminator + // CHECK: } // CHECK: omp.terminator // CHECK: } return @@ -23,20 +26,26 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) { - // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { - // CHECK: memref.alloca_scope + // CHECK: omp.wsloop { + // CHECK: omp.loop_nest (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { + // CHECK: memref.alloca_scope scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: omp.parallel - // CHECK: omp.wsloop for (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { + // CHECK: omp.wsloop { + // CHECK: omp.loop_nest (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { // CHECK: memref.alloca_scope scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> () "test.payload"(%i, %j) : (index, index) -> () // CHECK: } } - // CHECK: omp.yield + // CHECK: omp.yield + // CHECK: } + // CHECK: omp.terminator // CHECK: } } + // CHECK: omp.terminator + // CHECK: } // CHECK: omp.terminator // CHECK: } return @@ -47,7 +56,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) { - // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step 
(%arg4) { + // CHECK: omp.wsloop { + // CHECK: omp.loop_nest (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { // CHECK: memref.alloca_scope scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> () @@ -55,12 +65,15 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, // CHECK: omp.yield // CHECK: } } + // CHECK: omp.terminator + // CHECK: } // CHECK: omp.terminator // CHECK: } // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) { - // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { + // CHECK: omp.wsloop { + // CHECK: omp.loop_nest (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { // CHECK: memref.alloca_scope scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> () @@ -68,6 +81,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, // CHECK: omp.yield // CHECK: } } + // CHECK: omp.terminator + // CHECK: } // CHECK: omp.terminator // CHECK: } return >From f8c0897e775252b153899d685b774ef15e8759bb Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Fri, 19 Apr 2024 13:14:59 +0100 Subject: [PATCH 2/2] Address review comments --- mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index f0b8d6c5309357..d6f85451ee5d30 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -502,10 +502,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { loopOpEntryBlock.eraseArguments( numLoops, loopOpEntryBlock.getNumArguments() - numLoops); - Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(), - loopOp.getRegion().begin()->begin()); - - 
rewriter.setInsertionPointToStart(&*loopOp.getRegion().begin()); + Block *ops = + rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin()); + rewriter.setInsertionPointToStart(&loopOpEntryBlock); auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(), TypeRange()); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits