https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/137205
>From 4e97789c831bbfbfe85a9df7420df65de036b09a Mon Sep 17 00:00:00 2001 From: Tom Eccles <tom.ecc...@arm.com> Date: Tue, 15 Apr 2025 15:40:39 +0000 Subject: [PATCH] [mlir][OpenMP] Convert omp.cancellation_point to LLVMIR This is basically identical to cancel except without the if clause. taskgroup will be implemented in a followup PR. --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 10 + llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 51 +++++ .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 37 +++- .../LLVMIR/openmp-cancellation-point.mlir | 188 ++++++++++++++++++ mlir/test/Target/LLVMIR/openmp-todo.mlir | 16 +- 5 files changed, 293 insertions(+), 9 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 10d69e561a987..14ad8629537f7 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -686,6 +686,16 @@ class OpenMPIRBuilder { Value *IfCondition, omp::Directive CanceledDirective); + /// Generator for '#omp cancellation point' + /// + /// \param Loc The location where the directive was encountered. + /// \param CanceledDirective The kind of directive that is canceled. + /// + /// \returns The insertion point after the barrier. + InsertPointOrErrorTy + createCancellationPoint(const LocationDescription &Loc, + omp::Directive CanceledDirective); + /// Generator for '#omp parallel' /// /// \param Loc The insert and source location description. 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 3f19088e6c73d..06aa61adcd739 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1118,6 +1118,57 @@ OpenMPIRBuilder::createCancel(const LocationDescription &Loc, return Builder.saveIP(); } +OpenMPIRBuilder::InsertPointOrErrorTy +OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc, + omp::Directive CanceledDirective) { + if (!updateToLocation(Loc)) + return Loc.IP; + + // LLVM utilities like blocks with terminators. + auto *UI = Builder.CreateUnreachable(); + Builder.SetInsertPoint(UI); + + Value *CancelKind = nullptr; + switch (CanceledDirective) { +#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \ + case DirectiveEnum: \ + CancelKind = Builder.getInt32(Value); \ + break; +#include "llvm/Frontend/OpenMP/OMPKinds.def" + default: + llvm_unreachable("Unknown cancel kind!"); + } + + uint32_t SrcLocStrSize; + Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize); + Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); + Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind}; + Value *Result = Builder.CreateCall( + getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancellationpoint), Args); + auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error { + if (CanceledDirective == OMPD_parallel) { + IRBuilder<>::InsertPointGuard IPG(Builder); + Builder.restoreIP(IP); + return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL), + omp::Directive::OMPD_unknown, + /* ForceSimpleCall */ false, + /* CheckCancelFlag */ false) + .takeError(); + } + return Error::success(); + }; + + // The actual cancel logic is shared with others, e.g., cancel_barriers. + if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB)) + return Err; + + // Update the insertion point and remove the terminator we introduced. 
+ Builder.SetInsertPoint(UI->getParent()); + UI->eraseFromParent(); + + return Builder.saveIP(); +} + OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel( const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return, Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads, diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9d181f12bc773..228c767699d72 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -255,6 +255,9 @@ static LogicalResult checkImplementationStatus(Operation &op) { LogicalResult result = success(); llvm::TypeSwitch<Operation &>(op) .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); }) + .Case([&](omp::CancellationPointOp op) { + checkCancelDirective(op, result); + }) .Case([&](omp::DistributeOp op) { checkAllocate(op, result); checkDistSchedule(op, result); @@ -1589,11 +1592,12 @@ cleanupPrivateVars(llvm::IRBuilderBase &builder, /// Returns true if the construct contains omp.cancel or omp.cancellation_point static bool constructIsCancellable(Operation *op) { - // omp.cancel must be "closely nested" so it will be visible and not inside of - // funcion calls. This is enforced by the verifier. + // omp.cancel and omp.cancellation_point must be "closely nested" so they will + // be visible and not inside of function calls. This is enforced by the + // verifier. 
return op ->walk([](Operation *child) { - if (mlir::isa<omp::CancelOp>(child)) + if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child)) return WalkResult::interrupt(); return WalkResult::advance(); }) @@ -3089,6 +3093,30 @@ convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, return success(); } +static LogicalResult +convertOmpCancellationPoint(omp::CancellationPointOp op, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + + if (failed(checkImplementationStatus(*op.getOperation()))) + return failure(); + + llvm::omp::Directive cancelledDirective = + convertCancellationConstructType(op.getCancelDirective()); + + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = + ompBuilder->createCancellationPoint(ompLoc, cancelledDirective); + + if (failed(handleError(afterIP, *op.getOperation()))) + return failure(); + + builder.restoreIP(afterIP.get()); + + return success(); +} + /// Converts an OpenMP Threadprivate operation into LLVM IR using /// OpenMPIRBuilder. 
static LogicalResult @@ -5522,6 +5550,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, .Case([&](omp::CancelOp op) { return convertOmpCancel(op, builder, moduleTranslation); }) + .Case([&](omp::CancellationPointOp op) { + return convertOmpCancellationPoint(op, builder, moduleTranslation); + }) .Case([&](omp::SectionsOp) { return convertOmpSections(*op, builder, moduleTranslation); }) diff --git a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir new file mode 100644 index 0000000000000..bbb313c113567 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir @@ -0,0 +1,188 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +llvm.func @cancellation_point_parallel() { + omp.parallel { + omp.cancellation_point cancellation_construct_type(parallel) + omp.terminator + } + llvm.return +} +// CHECK-LABEL: define internal void @cancellation_point_parallel..omp_par +// CHECK: omp.par.entry: +// CHECK: %[[VAL_5:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_6:.*]] = load i32, ptr %[[VAL_7:.*]], align 4 +// CHECK: store i32 %[[VAL_6]], ptr %[[VAL_5]], align 4 +// CHECK: %[[VAL_8:.*]] = load i32, ptr %[[VAL_5]], align 4 +// CHECK: br label %[[VAL_9:.*]] +// CHECK: omp.region.after_alloca: ; preds = %[[VAL_10:.*]] +// CHECK: br label %[[VAL_11:.*]] +// CHECK: omp.par.region: ; preds = %[[VAL_9]] +// CHECK: br label %[[VAL_12:.*]] +// CHECK: omp.par.region1: ; preds = %[[VAL_11]] +// CHECK: %[[VAL_13:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[VAL_14:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_13]], i32 1) +// CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 +// CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] +// CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 
%[[VAL_18]]) +// CHECK: br label %[[VAL_20:.*]] +// CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] +// CHECK: br label %[[VAL_21:.*]] +// CHECK: omp.region.cont: ; preds = %[[VAL_16]] +// CHECK: br label %[[VAL_22:.*]] +// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_20]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] +// CHECK: ret void + +llvm.func @cancellation_point_sections() { + omp.sections { + omp.section { + omp.cancellation_point cancellation_construct_type(sections) + omp.terminator + } + omp.terminator + } + llvm.return +} +// CHECK-LABEL: define void @cancellation_point_sections +// CHECK: %[[VAL_23:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_24:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_25:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_26:.*]] = alloca i32, align 4 +// CHECK: br label %[[VAL_27:.*]] +// CHECK: entry: ; preds = %[[VAL_28:.*]] +// CHECK: br label %[[VAL_29:.*]] +// CHECK: omp_section_loop.preheader: ; preds = %[[VAL_27]] +// CHECK: store i32 0, ptr %[[VAL_24]], align 4 +// CHECK: store i32 0, ptr %[[VAL_25]], align 4 +// CHECK: store i32 1, ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_30]], i32 34, ptr %[[VAL_23]], ptr %[[VAL_24]], ptr %[[VAL_25]], ptr %[[VAL_26]], i32 1, i32 0) +// CHECK: %[[VAL_31:.*]] = load i32, ptr %[[VAL_24]], align 4 +// CHECK: %[[VAL_32:.*]] = load i32, ptr %[[VAL_25]], align 4 +// CHECK: %[[VAL_33:.*]] = sub i32 %[[VAL_32]], %[[VAL_31]] +// CHECK: %[[VAL_34:.*]] = add i32 %[[VAL_33]], 1 +// CHECK: br label %[[VAL_35:.*]] +// CHECK: omp_section_loop.header: ; preds = %[[VAL_36:.*]], %[[VAL_29]] +// CHECK: %[[VAL_37:.*]] = phi i32 [ 0, %[[VAL_29]] ], [ %[[VAL_38:.*]], %[[VAL_36]] ] +// CHECK: br label %[[VAL_39:.*]] +// CHECK: omp_section_loop.cond: ; preds = %[[VAL_35]] +// CHECK: %[[VAL_40:.*]] = icmp ult i32 %[[VAL_37]], %[[VAL_34]] 
+// CHECK: br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_42:.*]] +// CHECK: omp_section_loop.body: ; preds = %[[VAL_39]] +// CHECK: %[[VAL_43:.*]] = add i32 %[[VAL_37]], %[[VAL_31]] +// CHECK: %[[VAL_44:.*]] = mul i32 %[[VAL_43]], 1 +// CHECK: %[[VAL_45:.*]] = add i32 %[[VAL_44]], 0 +// CHECK: switch i32 %[[VAL_45]], label %[[VAL_46:.*]] [ +// CHECK: i32 0, label %[[VAL_47:.*]] +// CHECK: ] +// CHECK: omp_section_loop.body.case: ; preds = %[[VAL_41]] +// CHECK: br label %[[VAL_48:.*]] +// CHECK: omp.section.region: ; preds = %[[VAL_47]] +// CHECK: %[[VAL_49:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[VAL_50:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_49]], i32 3) +// CHECK: %[[VAL_51:.*]] = icmp eq i32 %[[VAL_50]], 0 +// CHECK: br i1 %[[VAL_51]], label %[[VAL_52:.*]], label %[[VAL_53:.*]] +// CHECK: omp.section.region.split: ; preds = %[[VAL_48]] +// CHECK: br label %[[VAL_54:.*]] +// CHECK: omp.region.cont: ; preds = %[[VAL_52]] +// CHECK: br label %[[VAL_46]] +// CHECK: omp_section_loop.body.sections.after: ; preds = %[[VAL_54]], %[[VAL_41]] +// CHECK: br label %[[VAL_36]] +// CHECK: omp_section_loop.inc: ; preds = %[[VAL_46]] +// CHECK: %[[VAL_38]] = add nuw i32 %[[VAL_37]], 1 +// CHECK: br label %[[VAL_35]] +// CHECK: omp_section_loop.exit: ; preds = %[[VAL_53]], %[[VAL_39]] +// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_30]]) +// CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]]) +// CHECK: br label %[[VAL_56:.*]] +// CHECK: omp_section_loop.after: ; preds = %[[VAL_42]] +// CHECK: br label %[[VAL_57:.*]] +// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_56]] +// CHECK: ret void +// CHECK: omp.section.region.cncl: ; preds = %[[VAL_48]] +// CHECK: br label %[[VAL_42]] + +llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step 
(%step) { + omp.cancellation_point cancellation_construct_type(loop) + omp.yield + } + } + llvm.return +} +// CHECK-LABEL: define void @cancellation_point_wsloop +// CHECK: %[[VAL_58:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_59:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_60:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_61:.*]] = alloca i32, align 4 +// CHECK: br label %[[VAL_62:.*]] +// CHECK: omp.region.after_alloca: ; preds = %[[VAL_63:.*]] +// CHECK: br label %[[VAL_64:.*]] +// CHECK: entry: ; preds = %[[VAL_62]] +// CHECK: br label %[[VAL_65:.*]] +// CHECK: omp.wsloop.region: ; preds = %[[VAL_64]] +// CHECK: %[[VAL_66:.*]] = icmp slt i32 %[[VAL_67:.*]], 0 +// CHECK: %[[VAL_68:.*]] = sub i32 0, %[[VAL_67]] +// CHECK: %[[VAL_69:.*]] = select i1 %[[VAL_66]], i32 %[[VAL_68]], i32 %[[VAL_67]] +// CHECK: %[[VAL_70:.*]] = select i1 %[[VAL_66]], i32 %[[VAL_71:.*]], i32 %[[VAL_72:.*]] +// CHECK: %[[VAL_73:.*]] = select i1 %[[VAL_66]], i32 %[[VAL_72]], i32 %[[VAL_71]] +// CHECK: %[[VAL_74:.*]] = sub nsw i32 %[[VAL_73]], %[[VAL_70]] +// CHECK: %[[VAL_75:.*]] = icmp sle i32 %[[VAL_73]], %[[VAL_70]] +// CHECK: %[[VAL_76:.*]] = sub i32 %[[VAL_74]], 1 +// CHECK: %[[VAL_77:.*]] = udiv i32 %[[VAL_76]], %[[VAL_69]] +// CHECK: %[[VAL_78:.*]] = add i32 %[[VAL_77]], 1 +// CHECK: %[[VAL_79:.*]] = icmp ule i32 %[[VAL_74]], %[[VAL_69]] +// CHECK: %[[VAL_80:.*]] = select i1 %[[VAL_79]], i32 1, i32 %[[VAL_78]] +// CHECK: %[[VAL_81:.*]] = select i1 %[[VAL_75]], i32 0, i32 %[[VAL_80]] +// CHECK: br label %[[VAL_82:.*]] +// CHECK: omp_loop.preheader: ; preds = %[[VAL_65]] +// CHECK: store i32 0, ptr %[[VAL_59]], align 4 +// CHECK: %[[VAL_83:.*]] = sub i32 %[[VAL_81]], 1 +// CHECK: store i32 %[[VAL_83]], ptr %[[VAL_60]], align 4 +// CHECK: store i32 1, ptr %[[VAL_61]], align 4 +// CHECK: %[[VAL_84:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_84]], i32 34, ptr %[[VAL_58]], ptr %[[VAL_59]], ptr %[[VAL_60]], ptr 
%[[VAL_61]], i32 1, i32 0) +// CHECK: %[[VAL_85:.*]] = load i32, ptr %[[VAL_59]], align 4 +// CHECK: %[[VAL_86:.*]] = load i32, ptr %[[VAL_60]], align 4 +// CHECK: %[[VAL_87:.*]] = sub i32 %[[VAL_86]], %[[VAL_85]] +// CHECK: %[[VAL_88:.*]] = add i32 %[[VAL_87]], 1 +// CHECK: br label %[[VAL_89:.*]] +// CHECK: omp_loop.header: ; preds = %[[VAL_90:.*]], %[[VAL_82]] +// CHECK: %[[VAL_91:.*]] = phi i32 [ 0, %[[VAL_82]] ], [ %[[VAL_92:.*]], %[[VAL_90]] ] +// CHECK: br label %[[VAL_93:.*]] +// CHECK: omp_loop.cond: ; preds = %[[VAL_89]] +// CHECK: %[[VAL_94:.*]] = icmp ult i32 %[[VAL_91]], %[[VAL_88]] +// CHECK: br i1 %[[VAL_94]], label %[[VAL_95:.*]], label %[[VAL_96:.*]] +// CHECK: omp_loop.body: ; preds = %[[VAL_93]] +// CHECK: %[[VAL_97:.*]] = add i32 %[[VAL_91]], %[[VAL_85]] +// CHECK: %[[VAL_98:.*]] = mul i32 %[[VAL_97]], %[[VAL_67]] +// CHECK: %[[VAL_99:.*]] = add i32 %[[VAL_98]], %[[VAL_72]] +// CHECK: br label %[[VAL_100:.*]] +// CHECK: omp.loop_nest.region: ; preds = %[[VAL_95]] +// CHECK: %[[VAL_101:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[VAL_102:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_101]], i32 2) +// CHECK: %[[VAL_103:.*]] = icmp eq i32 %[[VAL_102]], 0 +// CHECK: br i1 %[[VAL_103]], label %[[VAL_104:.*]], label %[[VAL_105:.*]] +// CHECK: omp.loop_nest.region.split: ; preds = %[[VAL_100]] +// CHECK: br label %[[VAL_106:.*]] +// CHECK: omp.region.cont1: ; preds = %[[VAL_104]] +// CHECK: br label %[[VAL_90]] +// CHECK: omp_loop.inc: ; preds = %[[VAL_106]] +// CHECK: %[[VAL_92]] = add nuw i32 %[[VAL_91]], 1 +// CHECK: br label %[[VAL_89]] +// CHECK: omp_loop.exit: ; preds = %[[VAL_105]], %[[VAL_93]] +// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_84]]) +// CHECK: %[[VAL_107:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_107]]) +// CHECK: br label %[[VAL_108:.*]] +// CHECK: omp_loop.after: ; preds = %[[VAL_96]] +// CHECK: br label 
%[[VAL_109:.*]] +// CHECK: omp.region.cont: ; preds = %[[VAL_108]] +// CHECK: ret void +// CHECK: omp.loop_nest.region.cncl: ; preds = %[[VAL_100]] +// CHECK: br label %[[VAL_96]] diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index ed355096b702e..789c0ad9ebb48 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -43,12 +43,16 @@ llvm.func @cancel_taskgroup() { // ----- -llvm.func @cancellation_point() { - // expected-error@below {{LLVM Translation failed for operation: omp.parallel}} - omp.parallel { - // expected-error@below {{not yet implemented: omp.cancellation_point}} - // expected-error@below {{LLVM Translation failed for operation: omp.cancellation_point}} - omp.cancellation_point cancellation_construct_type(parallel) +llvm.func @cancellation_point_taskgroup() { + // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}} + omp.taskgroup { + // expected-error@below {{LLVM Translation failed for operation: omp.task}} + omp.task { + // expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancellation_point operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.cancellation_point}} + omp.cancellation_point cancellation_construct_type(taskgroup) + omp.terminator + } omp.terminator } llvm.return _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits