https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/87232
>From 2452bc75a7f2efb67a0522bbe8b0e7ba5bc3365b Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Mon, 1 Apr 2024 13:04:14 +0100 Subject: [PATCH 1/2] [MLIR][OpenMP] Introduce the LoopWrapperInterface This patch defines a common interface to be shared by all OpenMP loop wrapper operations. The main restrictions these operations must meet in order to be considered a wrapper are: - They contain a single region. - Their region contains a single block. - Their block only contains another loop wrapper or `omp.loop_nest` and a terminator. The new interface is attached to the `omp.parallel`, `omp.wsloop`, `omp.simdloop`, `omp.distribute` and `omp.taskloop` operations. These operations are not yet required to meet the wrapper restrictions, because enforcing them now would break existing OpenMP loop-generating code. Instead, that enforcement will be introduced progressively in subsequent patches. --- .../mlir/Dialect/OpenMP/OpenMPInterfaces.h | 3 + mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 16 +++-- .../Dialect/OpenMP/OpenMPOpsInterfaces.td | 68 +++++++++++++++++++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 19 ++++++ mlir/test/Dialect/OpenMP/invalid.mlir | 16 ++++- 5 files changed, 117 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h index b3184db8852161..787c48b05c5c5c 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h @@ -21,6 +21,9 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#define GET_OP_FWD_DEFINES +#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc" + #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.h.inc" namespace mlir::omp { diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ffd00948915153..a7bf93deae2fb3 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ 
b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -236,6 +236,7 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove]> { def ParallelOp : OpenMP_Op<"parallel", [ AutomaticAllocationScope, AttrSizedOperandSegments, + DeclareOpInterfaceMethods<LoopWrapperInterface>, DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>, RecursiveMemoryEffects, ReductionClauseInterface]> { let summary = "parallel construct"; @@ -517,8 +518,6 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> { def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize, AllTypesMatch<["lowerBound", "upperBound", "step"]>, - ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskloopOp", - "WsloopOp"]>, RecursiveMemoryEffects]> { let summary = "rectangular loop nest"; let description = [{ @@ -568,6 +567,10 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize, /// Returns the induction variables of the loop nest. ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); } + + /// Returns the list of wrapper operations around this loop nest. Wrappers + /// in the resulting vector will be sorted from innermost to outermost. 
+ SmallVector<LoopWrapperInterface> getWrappers(); }]; let hasCustomAssemblyFormat = 1; @@ -580,6 +583,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize, def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, AllTypesMatch<["lowerBound", "upperBound", "step"]>, + DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects, ReductionClauseInterface]> { let summary = "worksharing-loop construct"; let description = [{ @@ -700,7 +704,9 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, //===----------------------------------------------------------------------===// def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments, - AllTypesMatch<["lowerBound", "upperBound", "step"]>]> { + AllTypesMatch<["lowerBound", "upperBound", "step"]>, + DeclareOpInterfaceMethods<LoopWrapperInterface>, + RecursiveMemoryEffects]> { let summary = "simd loop construct"; let description = [{ The simd construct can be applied to a loop to indicate that the loop can be @@ -809,7 +815,8 @@ def YieldOp : OpenMP_Op<"yield", // Distribute construct [2.9.4.1] //===----------------------------------------------------------------------===// def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments, - MemoryEffects<[MemWrite]>]> { + DeclareOpInterfaceMethods<LoopWrapperInterface>, + RecursiveMemoryEffects]> { let summary = "distribute construct"; let description = [{ The distribute construct specifies that the iterations of one or more loops @@ -980,6 +987,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments, def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments, AutomaticAllocationScope, RecursiveMemoryEffects, AllTypesMatch<["lowerBound", "upperBound", "step"]>, + DeclareOpInterfaceMethods<LoopWrapperInterface>, ReductionClauseInterface]> { let summary = "taskloop construct"; let description = [{ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td 
b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index 2e37384ce3eb71..b6a3560b7da56a 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -69,6 +69,74 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> { ]; } +def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { + let description = [{ + OpenMP operations that can wrap a single loop nest. When taking a wrapper + role, these operations must only contain a single region with a single block + in which there's a single operation and a terminator. That nested operation + must be another loop wrapper or an `omp.loop_nest`. + }]; + + let cppNamespace = "::mlir::omp"; + + let methods = [ + InterfaceMethod< + /*description=*/[{ + Tell whether the operation could be taking the role of a loop wrapper. + That is, it has a single region with a single block in which there are + two operations: another wrapper or `omp.loop_nest` operation and a + terminator. + }], + /*retTy=*/"bool", + /*methodName=*/"isWrapper", + (ins ), [{}], [{ + if ($_op->getNumRegions() != 1) + return false; + + ::mlir::Region &r = $_op->getRegion(0); + if (!r.hasOneBlock()) + return false; + + if (std::distance(r.op_begin(), r.op_end()) != 2) + return false; + + ::mlir::Operation &firstOp = *r.op_begin(); + ::mlir::Operation &secondOp = *(++r.op_begin()); + return ::llvm::isa<::mlir::omp::LoopNestOp, + ::mlir::omp::LoopWrapperInterface>(firstOp) && + secondOp.hasTrait<::mlir::OpTrait::IsTerminator>(); + }] + >, + InterfaceMethod< + /*description=*/[{ + If there is another loop wrapper immediately nested inside, return that + operation. Assumes this operation is taking a loop wrapper role. 
+ }], + /*retTy=*/"::mlir::omp::LoopWrapperInterface", + /*methodName=*/"getNestedWrapper", + (ins), [{}], [{ + assert($_op.isWrapper() && "Unexpected non-wrapper op"); + ::mlir::Operation *nested = &*$_op->getRegion(0).op_begin(); + return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested); + }] + >, + InterfaceMethod< + /*description=*/[{ + Return the loop nest nested directly or indirectly inside of this loop + wrapper. Assumes this operation is taking a loop wrapper role. + }], + /*retTy=*/"::mlir::Operation *", + /*methodName=*/"getWrappedLoop", + (ins), [{}], [{ + assert($_op.isWrapper() && "Unexpected non-wrapper op"); + if (::mlir::omp::LoopWrapperInterface nested = $_op.getNestedWrapper()) + return nested.getWrappedLoop(); + return &*$_op->getRegion(0).op_begin(); + }] + > + ]; +} + def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> { let description = [{ OpenMP operations that support declare target have this interface. diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 796df1d13e6564..564c23201db4fd 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1730,9 +1730,28 @@ LogicalResult LoopNestOp::verify() { << "range argument type does not match corresponding IV type"; } + auto wrapper = + llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()); + + if (!wrapper || !wrapper.isWrapper()) + return emitOpError() << "expects parent op to be a valid loop wrapper"; + return success(); } +SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() { + SmallVector<LoopWrapperInterface> wrappers; + Operation *parent = (*this)->getParentOp(); + while (auto wrapper = + llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) { + if (!wrapper.isWrapper()) + break; + wrappers.push_back(wrapper); + parent = parent->getParentOp(); + } + return wrappers; +} + 
//===----------------------------------------------------------------------===// // WsloopOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 760ebb14d94121..8f4103dabee5df 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -88,7 +88,7 @@ func.func @proc_bind_once() { // ----- func.func @invalid_parent(%lb : index, %ub : index, %step : index) { - // expected-error@+1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}} + // expected-error@+1 {{op expects parent op to be a valid loop wrapper}} omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { omp.yield } @@ -96,6 +96,20 @@ func.func @invalid_parent(%lb : index, %ub : index, %step : index) { // ----- +func.func @invalid_wrapper(%lb : index, %ub : index, %step : index) { + // TODO Remove induction variables from omp.wsloop. + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + %0 = arith.constant 0 : i32 + // expected-error@+1 {{op expects parent op to be a valid loop wrapper}} + omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) { + omp.yield + } + omp.yield + } +} + +// ----- + func.func @type_mismatch(%lb : index, %ub : index, %step : index) { // TODO Remove induction variables from omp.wsloop. 
omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { >From 904f27489b0d3c27e773f085b18a9b85cb548f44 Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Tue, 2 Apr 2024 15:36:41 +0100 Subject: [PATCH 2/2] Address review comments --- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 4 ++-- .../Dialect/OpenMP/OpenMPOpsInterfaces.td | 19 +++++++++---------- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5 ++--- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index a7bf93deae2fb3..50627712ea3109 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -568,9 +568,9 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize, /// Returns the induction variables of the loop nest. ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); } - /// Returns the list of wrapper operations around this loop nest. Wrappers + /// Fills a list of wrapper operations around this loop nest. Wrappers /// in the resulting vector will be sorted from innermost to outermost. 
- SmallVector<LoopWrapperInterface> getWrappers(); + void gatherWrappers(SmallVectorImpl<LoopWrapperInterface> &wrappers); }]; let hasCustomAssemblyFormat = 1; diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index b6a3560b7da56a..ab9b78e755d9d5 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -93,18 +93,17 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { if ($_op->getNumRegions() != 1) return false; - ::mlir::Region &r = $_op->getRegion(0); + Region &r = $_op->getRegion(0); if (!r.hasOneBlock()) return false; - if (std::distance(r.op_begin(), r.op_end()) != 2) + if (::llvm::range_size(r.getOps()) != 2) return false; - ::mlir::Operation &firstOp = *r.op_begin(); - ::mlir::Operation &secondOp = *(++r.op_begin()); - return ::llvm::isa<::mlir::omp::LoopNestOp, - ::mlir::omp::LoopWrapperInterface>(firstOp) && - secondOp.hasTrait<::mlir::OpTrait::IsTerminator>(); + Operation &firstOp = *r.op_begin(); + Operation &secondOp = *(std::next(r.op_begin())); + return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) && + secondOp.hasTrait<OpTrait::IsTerminator>(); }] >, InterfaceMethod< @@ -116,8 +115,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { /*methodName=*/"getNestedWrapper", (ins), [{}], [{ assert($_op.isWrapper() && "Unexpected non-wrapper op"); - ::mlir::Operation *nested = &*$_op->getRegion(0).op_begin(); - return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested); + Operation *nested = &*$_op->getRegion(0).op_begin(); + return ::llvm::dyn_cast<LoopWrapperInterface>(nested); }] >, InterfaceMethod< @@ -129,7 +128,7 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { /*methodName=*/"getWrappedLoop", (ins), [{}], [{ assert($_op.isWrapper() && "Unexpected non-wrapper op"); - if (::mlir::omp::LoopWrapperInterface nested = 
$_op.getNestedWrapper()) + if (LoopWrapperInterface nested = $_op.getNestedWrapper()) return nested.getWrappedLoop(); return &*$_op->getRegion(0).op_begin(); }] diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 564c23201db4fd..a7d265328df6ef 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1739,8 +1739,8 @@ LogicalResult LoopNestOp::verify() { return success(); } -SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() { - SmallVector<LoopWrapperInterface> wrappers; +void LoopNestOp::gatherWrappers( + SmallVectorImpl<LoopWrapperInterface> &wrappers) { Operation *parent = (*this)->getParentOp(); while (auto wrapper = llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) { @@ -1749,7 +1749,6 @@ SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() { wrappers.push_back(wrapper); parent = parent->getParentOp(); } - return wrappers; } //===----------------------------------------------------------------------===// _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits