llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Tom Eccles (tblah) <details> <summary>Changes</summary> Taskloop support will follow in a later patch. --- Full diff: https://github.com/llvm/llvm-project/pull/137194.diff 3 Files Affected: - (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+38-2) - (modified) mlir/test/Target/LLVMIR/openmp-cancel.mlir (+87) - (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-16) ``````````diff diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index d1885641f389d..7d8a7ccb6e4ac 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -161,8 +161,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { auto checkCancelDirective = [&todo](auto op, LogicalResult &result) { omp::ClauseCancellationConstructType cancelledDirective = op.getCancelDirective(); - if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel && - cancelledDirective != omp::ClauseCancellationConstructType::Sections) + if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) result = todo("cancel directive"); }; auto checkDepend = [&todo](auto op, LogicalResult &result) { @@ -2360,6 +2359,30 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop : llvm::omp::WorksharingLoopType::ForStaticLoop; + SmallVector<llvm::BranchInst *> cancelTerminators; + // This callback is invoked only if there is cancellation inside of the wsloop + // body. + auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error { + llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder; + llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder); + + // ip is currently in the block branched to if cancellation occured. + // We need to create a branch to terminate that block. + llvmBuilder.restoreIP(ip); + + // We must still clean up the wsloop after cancelling it, so we need to + // branch to the block that finalizes the wsloop. + // That block has not been created yet so use this block as a dummy for now + // and fix this after creating the wsloop. + cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock())); + return llvm::Error::success(); + }; + // We have to add the cleanup to the OpenMPIRBuilder before the body gets + // created in case the body contains omp.cancel (which will then expect to be + // able to find this cleanup callback). + ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for, + constructIsCancellable(wsloopOp)}); + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions( wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation); @@ -2381,6 +2404,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, if (failed(handleError(wsloopIP, opInst))) return failure(); + ompBuilder->popFinalizationCB(); + if (!cancelTerminators.empty()) { + // If we cancelled the loop, we should branch to the finalization block of + // the wsloop (which is always immediately before the loop continuation + // block). Now the finalization has been created, we can fix the branch. + llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor(); + for (llvm::BranchInst *cancelBranch : cancelTerminators) { + assert(cancelBranch->getNumSuccessors() == 1 && + "cancel branch should have one target"); + cancelBranch->setSuccessor(0, wsloopFini); + } + } + // Process the reductions if required. if (failed(createReductionsAndCleanup( wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls, diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir index fca16b936fc85..3c195a98d1000 100644 --- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir +++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir @@ -156,3 +156,90 @@ llvm.func @cancel_sections_if(%cond : i1) { // CHECK: ret void // CHECK: .cncl: ; preds = %[[VAL_27]] // CHECK: br label %[[VAL_19]] + +llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.cancel cancellation_construct_type(loop) if(%cond) + omp.yield + } + } + llvm.return +} +// CHECK-LABEL: define void @cancel_wsloop_if +// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 +// CHECK: br label %[[VAL_4:.*]] +// CHECK: omp.region.after_alloca: ; preds = %[[VAL_5:.*]] +// CHECK: br label %[[VAL_6:.*]] +// CHECK: entry: ; preds = %[[VAL_4]] +// CHECK: br label %[[VAL_7:.*]] +// CHECK: omp.wsloop.region: ; preds = %[[VAL_6]] +// CHECK: %[[VAL_8:.*]] = icmp slt i32 %[[VAL_9:.*]], 0 +// CHECK: %[[VAL_10:.*]] = sub i32 0, %[[VAL_9]] +// CHECK: %[[VAL_11:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_10]], i32 %[[VAL_9]] +// CHECK: %[[VAL_12:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_13:.*]], i32 %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_14]], i32 %[[VAL_13]] +// CHECK: %[[VAL_16:.*]] = sub nsw i32 %[[VAL_15]], %[[VAL_12]] +// CHECK: %[[VAL_17:.*]] = icmp sle i32 %[[VAL_15]], %[[VAL_12]] +// CHECK: %[[VAL_18:.*]] = sub i32 %[[VAL_16]], 1 +// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_18]], %[[VAL_11]] +// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_19]], 1 +// CHECK: %[[VAL_21:.*]] = icmp ule i32 %[[VAL_16]], %[[VAL_11]] +// CHECK: %[[VAL_22:.*]] = select i1 %[[VAL_21]], i32 1, i32 %[[VAL_20]] +// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_17]], i32 0, i32 %[[VAL_22]] +// CHECK: br label %[[VAL_24:.*]] +// CHECK: omp_loop.preheader: ; preds = %[[VAL_7]] +// CHECK: store i32 0, ptr %[[VAL_1]], align 4 +// CHECK: %[[VAL_25:.*]] = sub i32 %[[VAL_23]], 1 +// CHECK: store i32 %[[VAL_25]], ptr %[[VAL_2]], align 4 +// CHECK: store i32 1, ptr %[[VAL_3]], align 4 +// CHECK: %[[VAL_26:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_26]], i32 34, ptr %[[VAL_0]], ptr %[[VAL_1]], ptr %[[VAL_2]], ptr %[[VAL_3]], i32 1, i32 0) +// CHECK: %[[VAL_27:.*]] = load i32, ptr %[[VAL_1]], align 4 +// CHECK: %[[VAL_28:.*]] = load i32, ptr %[[VAL_2]], align 4 +// CHECK: %[[VAL_29:.*]] = sub i32 %[[VAL_28]], %[[VAL_27]] +// CHECK: %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1 +// CHECK: br label %[[VAL_31:.*]] +// CHECK: omp_loop.header: ; preds = %[[VAL_32:.*]], %[[VAL_24]] +// CHECK: %[[VAL_33:.*]] = phi i32 [ 0, %[[VAL_24]] ], [ %[[VAL_34:.*]], %[[VAL_32]] ] +// CHECK: br label %[[VAL_35:.*]] +// CHECK: omp_loop.cond: ; preds = %[[VAL_31]] +// CHECK: %[[VAL_36:.*]] = icmp ult i32 %[[VAL_33]], %[[VAL_30]] +// CHECK: br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]] +// CHECK: omp_loop.body: ; preds = %[[VAL_35]] +// CHECK: %[[VAL_39:.*]] = add i32 %[[VAL_33]], %[[VAL_27]] +// CHECK: %[[VAL_40:.*]] = mul i32 %[[VAL_39]], %[[VAL_9]] +// CHECK: %[[VAL_41:.*]] = add i32 %[[VAL_40]], %[[VAL_14]] +// CHECK: br label %[[VAL_42:.*]] +// CHECK: omp.loop_nest.region: ; preds = %[[VAL_37]] +// CHECK: br i1 %[[VAL_43:.*]], label %[[VAL_44:.*]], label %[[VAL_45:.*]] +// CHECK: 25: ; preds = %[[VAL_42]] +// CHECK: %[[VAL_46:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2) +// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0 +// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]] +// CHECK: .split: ; preds = %[[VAL_44]] +// CHECK: br label %[[VAL_51:.*]] +// CHECK: 28: ; preds = %[[VAL_42]] +// CHECK: br label %[[VAL_51]] +// CHECK: 29: ; preds = %[[VAL_45]], %[[VAL_49]] +// CHECK: br label %[[VAL_52:.*]] +// CHECK: omp.region.cont1: ; preds = %[[VAL_51]] +// CHECK: br label %[[VAL_32]] +// CHECK: omp_loop.inc: ; preds = %[[VAL_52]] +// CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1 +// CHECK: br label %[[VAL_31]] +// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]] +// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]]) +// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]]) +// CHECK: br label %[[VAL_54:.*]] +// CHECK: omp_loop.after: ; preds = %[[VAL_38]] +// CHECK: br label %[[VAL_55:.*]] +// CHECK: omp.region.cont: ; preds = %[[VAL_54]] +// CHECK: ret void +// CHECK: .cncl: ; preds = %[[VAL_44]] +// CHECK: br label %[[VAL_38]] diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 6f904d0647285..f8d720dfe420c 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -26,22 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) { // ----- -llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) { - // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} - omp.wsloop { - // expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}} - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - // expected-error@below {{not yet implemented: Unhandled clause cancel directive in omp.cancel operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.cancel}} - omp.cancel cancellation_construct_type(loop) - omp.yield - } - } - llvm.return -} - -// ----- - llvm.func @cancel_taskgroup() { // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}} omp.taskgroup { `````````` </details> https://github.com/llvm/llvm-project/pull/137194 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits