https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/127821
This patch implements MLIR to LLVM IR translation of host-evaluated loop bounds, completing initial support for `target teams distribute parallel do [simd]` and `target teams distribute [simd]`. From 33409d2b52bfb4c69f67bbde001de5ce48feb073 Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Wed, 19 Feb 2025 14:41:12 +0000 Subject: [PATCH] [MLIR][OpenMP] Support target SPMD This patch implements MLIR to LLVM IR translation of host-evaluated loop bounds, completing initial support for `target teams distribute parallel do [simd]` and `target teams distribute [simd]`. --- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 83 ++++++++++++---- .../Target/LLVMIR/openmp-target-spmd.mlir | 96 +++++++++++++++++++ mlir/test/Target/LLVMIR/openmp-todo.mlir | 24 ----- 3 files changed, 159 insertions(+), 44 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/openmp-target-spmd.mlir diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 7e8a9bdb5b133..93a88c89162d6 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -176,15 +176,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.getHint()) op.emitWarning("hint clause discarded"); }; - auto checkHostEval = [](auto op, LogicalResult &result) { - // Host evaluated clauses are supported, except for loop bounds. 
- for (BlockArgument arg : - cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs()) - for (Operation *user : arg.getUsers()) - if (isa<omp::LoopNestOp>(user)) - result = op.emitError("not yet implemented: host evaluation of loop " - "bounds in omp.target operation"); - }; auto checkInReduction = [&todo](auto op, LogicalResult &result) { if (!op.getInReductionVars().empty() || op.getInReductionByref() || op.getInReductionSyms()) @@ -321,7 +312,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkBare(op, result); checkDevice(op, result); checkHasDeviceAddr(op, result); - checkHostEval(op, result); checkInReduction(op, result); checkIsDevicePtr(op, result); checkPrivate(op, result); @@ -4054,9 +4044,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, /// /// Loop bounds and steps are only optionally populated, if output vectors are /// provided. -static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, - Value &numTeamsLower, Value &numTeamsUpper, - Value &threadLimit) { +static void +extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, + Value &numTeamsLower, Value &numTeamsUpper, + Value &threadLimit, + llvm::SmallVectorImpl<Value> *lowerBounds = nullptr, + llvm::SmallVectorImpl<Value> *upperBounds = nullptr, + llvm::SmallVectorImpl<Value> *steps = nullptr) { auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp); for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(), blockArgIface.getHostEvalBlockArgs())) { @@ -4081,11 +4075,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::LoopNestOp loopOp) { - // TODO: Extract bounds and step values. Currently, this cannot be - // reached because translation would have been stopped earlier as a - // result of `checkImplementationStatus` detecting and reporting - // this situation. 
- llvm_unreachable("unsupported host_eval use"); + auto processBounds = + [&](OperandRange opBounds, + llvm::SmallVectorImpl<Value> *outBounds) -> bool { + bool found = false; + for (auto [i, lb] : llvm::enumerate(opBounds)) { + if (lb == blockArg) { + found = true; + if (outBounds) + (*outBounds)[i] = hostEvalVar; + } + } + return found; + }; + bool found = + processBounds(loopOp.getLoopLowerBounds(), lowerBounds); + found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) || + found; + found = processBounds(loopOp.getLoopSteps(), steps) || found; + if (!found) + llvm_unreachable("unsupported host_eval use"); }) .Default([](Operation *) { llvm_unreachable("unsupported host_eval use"); @@ -4222,6 +4231,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, combinedMaxThreadsVal = maxThreadsVal; // Update kernel bounds structure for the `OpenMPIRBuilder` to use. + attrs.ExecFlags = targetOp.getKernelExecFlags(); attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; @@ -4239,9 +4249,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) { + omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>( + targetOp.getInnermostCapturedOmpOp()); + unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0; + Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit; + llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops), + steps(numLoops); extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper, - teamsThreadLimit); + teamsThreadLimit, &lowerBounds, &upperBounds, &steps); // TODO: Handle constant 'if' clauses. 
if (Value targetThreadLimit = targetOp.getThreadLimit()) @@ -4261,7 +4277,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, if (numThreads) attrs.MaxThreads = moduleTranslation.lookupValue(numThreads); - // TODO: Populate attrs.LoopTripCount if it is target SPMD. + if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + attrs.LoopTripCount = nullptr; + + // To calculate the trip count, we multiply together the trip counts of + // every collapsed canonical loop. We don't need to create the loop nests + // here, since we're only interested in the trip count. + for (auto [loopLower, loopUpper, loopStep] : + llvm::zip_equal(lowerBounds, upperBounds, steps)) { + llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower); + llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper); + llvm::Value *step = moduleTranslation.lookupValue(loopStep); + + llvm::OpenMPIRBuilder::LocationDescription loc(builder); + llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount( + loc, lowerBound, upperBound, step, /*IsSigned=*/true, + loopOp.getLoopInclusive()); + + if (!attrs.LoopTripCount) { + attrs.LoopTripCount = tripCount; + continue; + } + + // TODO: Enable UndefinedBehaviorSanitizer to diagnose an overflow here. 
+ attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount, + {}, /*HasNUW=*/true); + } + } } static LogicalResult diff --git a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir new file mode 100644 index 0000000000000..7930554cbe11a --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir @@ -0,0 +1,96 @@ +// RUN: split-file %s %t +// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST +// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE + +//--- host.mlir + +module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @main(%x : i32) { + omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) { + omp.teams { + omp.parallel { + omp.distribute { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// HOST-LABEL: define void @main +// HOST: %omp_loop.tripcount = {{.*}} +// HOST-NEXT: br label %[[ENTRY:.*]] +// HOST: [[ENTRY]]: +// HOST-NEXT: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64 +// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8 +// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]] +// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]]) +// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0 +// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}} +// HOST: [[OFFLOAD_FAILED]]: +// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}}) + +// HOST: define internal void @[[TARGET_OUTLINE]] +// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}}) + +// HOST: define internal void @[[TEAMS_OUTLINE]] +// HOST: call 
void{{.*}}@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}}) + +// HOST: define internal void @[[PARALLEL_OUTLINE]] +// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}}) + +// HOST: define internal void @[[DISTRIBUTE_OUTLINE]] +// HOST: call void @__kmpc_dist_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}}) + +//--- device.mlir + +module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} { + llvm.func @main(%x : i32) { + omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) { + omp.teams { + omp.parallel { + omp.distribute { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 2 +// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata" +// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy { +// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 0, i8 1, i8 [[EXEC_MODE:2]], {{.*}}}, +// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} } + +// DEVICE: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}}) +// DEVICE: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}}) +// DEVICE: call void @[[TARGET_OUTLINE:.*]]({{.*}}) +// DEVICE: call void @__kmpc_target_deinit() + +// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}}) +// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}) + +// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}}) +// DEVICE: call void 
@[[DISTRIBUTE_OUTLINE:.*]]({{.*}}) + +// DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}}) +// DEVICE: call void @__kmpc_distribute_for_static_loop{{.*}}({{.*}}) diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index d1c745af9bff5..f907bb3f94a2a 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -319,30 +319,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) { // ----- -llvm.func @target_host_eval(%x : i32) { - // expected-error@below {{not yet implemented: host evaluation of loop bounds in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) { - omp.teams { - omp.parallel { - omp.distribute { - omp.wsloop { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } {omp.composite} - } {omp.composite} - omp.terminator - } {omp.composite} - omp.terminator - } - omp.terminator - } - llvm.return -} - -// ----- - omp.declare_reduction @add_f32 : f32 init { ^bb0(%arg: f32): _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits