================
@@ -3160,6 +3065,98 @@ LogicalResult LoopWrapperInterface::verifyImpl() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ComposableOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Find the operation captured by this composable op.
+///
+/// Resolution order:
+///  - If this op is a loop wrapper, return the result of getWrappedLoop().
+///  - If this op is neither combined nor composite, return the op itself.
+///  - Otherwise scan the ops of region 0 in order and resolve through the
+///    first nested loop wrapper (its wrapped loop) or the first nested
+///    composable op (recursively), whichever comes first.
+///  - If no such nested op is found, return the op itself.
+Operation *ComposableOpInterface::findCapturedOp() {
+  Operation *op = this->getOperation();
+
+  // Handle the composite case by returning the wrapped omp.loop_nest.
+  if (auto wrapperOp = dyn_cast<LoopWrapperInterface>(op))
+    return wrapperOp.getWrappedLoop();
+
+  // Do not look further if this op is not combined with any of its children.
+  // Need to check for composite for the omp.parallel case, which is not a loop
+  // wrapper itself.
+  if (!isCombined() && !isComposite())
+    return op;
+
+  // Bind by reference: the region is only inspected here, not copied.
+  Region &region = op->getRegion(0);
+  for (Operation &nestedOp : region.getOps()) {
+    if (auto wrapperOp = dyn_cast<LoopWrapperInterface>(&nestedOp))
+      return wrapperOp.getWrappedLoop();
+
+    if (auto composableOp = dyn_cast<ComposableOpInterface>(&nestedOp))
+      return composableOp.findCapturedOp();
+  }
+
+  // This can only be reached if the op has an omp.combined attribute but the
+  // corresponding nested composable op has been deleted. In that case, it's
+  // correct to return this operation.
+  return op;
+}
+
+LogicalResult ComposableOpInterface::verifyImpl() {
+ Operation *op = this->getOperation();
+
+ if (op->getNumRegions() != 1)
+ return emitOpError() << "composable ops must have a single region";
+
+ if (isComposite() && !isa<LoopWrapperInterface, ParallelOp>(op))
+ return emitOpError() << "non-loop wrapper cannot be composite";
+
+ // If combined, must have exactly one eligible nested op (composable or loop
+ // wrapper).
+ if (isCombined()) {
+ Operation *nestedOp = nullptr;
+ auto count = llvm::count_if(
+ op->getRegion(0).getOps(), [&nestedOp](mlir::Operation &op) {
+ if (isa<ComposableOpInterface, LoopWrapperInterface>(op)) {
+ nestedOp = &op;
+ return true;
+ }
+ return false;
+ });
----------------
skatrak wrote:
That's actually one of the big reasons I've gone through the process of making
the frontend responsible for tagging combined constructs. The previous
pattern-based approach actually did something like that: "a couple of nested
OpenMP constructs are combined if there are no other OpenMP or side-effecting
ops, or control flow in the parent op's region".
That regularly broke because of e.g. debug ops that had side effects just to
avoid being optimized out or due to how certain clauses of the child op were
lowered to MLIR, potentially adding memory allocations, etc. The current
pattern (`TargetOp::getInnermostCapturedOmpOp` + `findCapturedOmpOp`) kind of
works around those cases that we've detected so far, while also allowing
operations that should probably break it.
Checks here after this change are more lenient in that sense on purpose. We
just accept that we need to trust the frontend on this because it can't be
reliably checked in MLIR. The thinking here is that a frontend adds the
`omp.combined` attribute if it **knows** it can do so, but it would always
(minus the `target teams ompx_bare` case) be correct to not put it there,
though this could have a potential performance impact.
https://github.com/llvm/llvm-project/pull/198782
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits