================
@@ -1805,11 +1813,212 @@ class CIRTryOpFlattening : public
mlir::OpRewritePattern<cir::TryOp> {
}
};
+/// Returns a block whose first operation is the coroutine suspend-point
+/// constant of \p funcOp, splitting the constant's block if necessary.
+///
+/// The entry block of \p funcOp is scanned for the alloca for which
+/// getCoroutineSuspendPoint() is set. That alloca must have exactly one use,
+/// a cir::StoreOp, and the stored value must be defined by a
+/// cir::ConstantOp; both are enforced by the casts/asserts below.
+///
+/// \param funcOp   function whose entry block holds the suspend-point alloca.
+/// \param rewriter rewriter used to split the block and to create the branch;
+///                 its insertion point is restored before returning.
+/// \param loc      location attached to the branch created when splitting.
+/// \returns the block that starts with the suspend-point constant: the
+///          constant's existing block when it is already the first operation
+///          there, otherwise the new block produced by the split.
+static mlir::Block *getOrCreateBlockForSuspendPoint(
+    cir::FuncOp funcOp, mlir::PatternRewriter &rewriter, mlir::Location loc) {
+  mlir::Block &entryBlock = funcOp.getBody().front();
+
+  auto it = llvm::find_if(entryBlock, [](auto &op) {
+    return mlir::isa<AllocaOp>(&op) &&
+           mlir::cast<AllocaOp>(&op).getCoroutineSuspendPoint();
+  });
+
+  // find_if yields end() when no operation matches; dereferencing that
+  // iterator is undefined behaviour, so check it before touching `it`.
+  assert(it != entryBlock.end() &&
+         "expected a coroutine suspend point alloca in the entry block");
+  assert(it->hasOneUse() &&
+         "coroutine suspend point alloca must have exactly one use");
+  auto storeOp = cast<cir::StoreOp>(*it->getUses().begin()->getOwner());
+  auto suspendPoint =
+      cast<cir::ConstantOp>(storeOp.getValue().getDefiningOp());
+  mlir::Block *suspendBlock = suspendPoint->getBlock();
+  // Nothing to split when the constant already leads its block.
+  if (&suspendBlock->front() == suspendPoint)
+    return suspendBlock;
+
+  // Split right before the constant and reconnect the two halves with an
+  // unconditional branch appended to the original block.
+  mlir::OpBuilder::InsertionGuard guard(rewriter);
+  mlir::Block *remainingBlock =
+      rewriter.splitBlock(suspendBlock, suspendPoint->getIterator());
+  rewriter.setInsertionPointToEnd(suspendBlock);
+  cir::BrOp::create(rewriter, loc, remainingBlock);
+  return remainingBlock;
+}
+
+class CIRAwaitOpFlattening : public mlir::OpRewritePattern<cir::AwaitOp> {
+public:
+ using OpRewritePattern<cir::AwaitOp>::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::AwaitOp awaitOp,
+ mlir::PatternRewriter &rewriter) const override {
+ mlir::Block *awaitBlock = rewriter.getInsertionBlock();
+ mlir::Block *remainingOpsBlock =
+ rewriter.splitBlock(awaitBlock, rewriter.getInsertionPoint());
+
+ mlir::Location loc = awaitOp.getLoc();
+
+ mlir::Region &readyRegion = awaitOp.getReady();
+ mlir::Block &beforeReady = awaitOp.getReady().front();
+ mlir::Region &suspendRegion = awaitOp.getSuspend();
+ mlir::Region &resumeRegion = awaitOp.getResume();
+ auto conditionOp =
+ cast<cir::ConditionOp>(readyRegion.back().getTerminator());
+ {
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(conditionOp);
+ rewriter.replaceOpWithNewOp<cir::BrCondOp>(
+ conditionOp, conditionOp.getCondition(), &resumeRegion.front(),
+ &suspendRegion.front());
+ }
+ rewriter.inlineRegionBefore(readyRegion, remainingOpsBlock);
+
+ {
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToEnd(awaitBlock);
+ cir::BrOp::create(rewriter, loc, mlir::ValueRange(), &beforeReady);
+ }
+
+ auto suspendYield =
+ cast<cir::YieldOp>(suspendRegion.back().getTerminator());
+ cir::LLVMIntrinsicCallOp coroSuspendIntri = nullptr;
+ {
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(&suspendRegion.front().front());
+
+ // Insert coro.save at the beginning of the suspend region.
+ // This captures the current coroutine state before suspension.
+ auto voidPtrTy = cir::PointerType::get(cir::VoidType::get(getContext()));
+ auto nullPtr = cir::ConstantOp::create(
+ rewriter, loc,
+ cir::ConstPtrAttr::get(voidPtrTy, rewriter.getI64IntegerAttr(0)));
+ auto coroSaveIntri = cir::LLVMIntrinsicCallOp::create(
+ rewriter, loc, mlir::StringAttr::get(getContext(), "llvm.coro.save"),
+ cir::IntType::get(getContext(), 32, false),
+ mlir::ValueRange{nullPtr});
+ rewriter.setInsertionPoint(suspendYield);
+
+ bool isFinalSuspend = awaitOp.getKind() == cir::AwaitKind::Final;
+ auto isFinalCoroSuspend = cir::ConstantOp::create(
+ rewriter, loc, cir::BoolAttr::get(getContext(), isFinalSuspend));
+
+ // llvm.coro.suspend returns:
+ // -1 : coroutine suspended
+ // 0 : coroutine resumed
+ // 1 : coroutine destroyed
+ coroSuspendIntri = cir::LLVMIntrinsicCallOp::create(
+ rewriter, loc,
+ mlir::StringAttr::get(getContext(), "llvm.coro.suspend"),
+ cir::IntType::get(getContext(), 32, false),
+ mlir::ValueRange{coroSaveIntri.getResult(), isFinalCoroSuspend});
+ }
+ rewriter.inlineRegionBefore(suspendRegion, remainingOpsBlock);
+
+ auto func = awaitOp->getParentOfType<cir::FuncOp>();
+
+ {
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(suspendYield);
+ llvm::SmallVector<mlir::APInt, 2> caseValues{mlir::APInt(32, 0),
+ mlir::APInt(32, 1)};
+
+ llvm::SmallVector<mlir::ValueRange, 8> caseOperands{
----------------
Andres-Salamanca wrote:
I don't remember why I used 8 here 😅. I've changed it.
https://github.com/llvm/llvm-project/pull/203802
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits