https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171825
>From e3e8ed5a2bf6d33716efb6741d03891bfe3f6947 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 16 Jan 2026 08:48:30 +0530 Subject: [PATCH 1/9] Update num_teams to have just the list and no dims(N) syntax --- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 2 +- mlir/test/Dialect/OpenMP/ops.mlir | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 67ff9023a38da..e5b98024dbed1 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4588,7 +4588,7 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, p << " : " << upperBoundType; } else { // Upper only: to upper : type - p << " to "; + p << "to "; p.printOperand(upperBound); p << " : " << upperBoundType; } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 89c7e5fd48bd9..d28f31c8328b2 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1103,7 +1103,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, omp.terminator } - // CHECK: omp.teams num_teams( to %{{.+}} : i32) + // CHECK: omp.teams num_teams(to %{{.+}} : i32) omp.teams num_teams(to %ub : i32) { // CHECK: omp.terminator omp.terminator @@ -3084,7 +3084,7 @@ func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<? func.func @omp_target_host_eval(%x : i32) { // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) { - // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32) + // CHECK: omp.teams num_teams(to %[[HOST_ARG]] : i32) // CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32) omp.target host_eval(%x -> %arg0 : i32) { omp.teams num_teams( to %arg0 : i32) thread_limit(%arg0 : i32) { >From 858f03936cb9423012d8a859cf3b016d67355f1f Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 13:35:05 +0530 Subject: [PATCH 2/9] [OpenMP][MLIR] Add thread_limit with dims modifier support --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 16 +- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 29 +++- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 69 ++++++++- mlir/test/Dialect/OpenMP/invalid.mlir | 139 +++++++++++++++++- mlir/test/Dialect/OpenMP/ops.mlir | 8 +- 5 files changed, 249 insertions(+), 12 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 7b61539984232..a3b9e5c76bdd2 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -766,6 +766,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), newTargetOp.getRegion().begin()); @@ -1485,8 +1486,9 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), - targetOp.getPrivateMapsAttr()); + targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); auto *preTargetBlock = rewriter.createBlock( &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); IRMapping preMapping; @@ -1575,8 +1577,9 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), - targetOp.getPrivateMapsAttr()); + targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); auto *isolatedTargetBlock = rewriter.createBlock(&isolatedTargetOp.getRegion(), isolatedTargetOp.getRegion().begin(), {}, {}); @@ -1655,8 +1658,9 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), - targetOp.getPrivateMapsAttr()); + targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); // Create the block for postTargetOp auto *postTargetBlock = rewriter.createBlock( &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index d4640f254ed1f..d2ebd29229e84 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1452,16 +1452,43 @@ class OpenMP_ThreadLimitClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims, + Variadic<AnyInteger>:$thread_limit_dims_values, Optional<AnyInteger>:$thread_limit ); let optAssemblyFormat = [{ - `thread_limit` `(` $thread_limit `:` type($thread_limit) `)` + `thread_limit` `(` custom<ThreadLimitClause>( + $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values), + $thread_limit, type($thread_limit) + ) `)` }]; let description = [{ The optional `thread_limit` specifies the limit on the number of threads. }]; + + let extraClassDeclaration = [{ + /// Returns true if the dims modifier is explicitly present + bool hasThreadLimitDimsModifier() { + return getThreadLimitNumDims().has_value() && getThreadLimitNumDims().value(); + } + + /// Returns the number of dimensions specified by dims modifier + unsigned getThreadLimitDimsCount() { + if (!hasThreadLimitDimsModifier()) + return 1; + return static_cast<unsigned>(*getThreadLimitNumDims()); + } + + /// Returns the value for a specific dimension index + /// Index must be less than getThreadLimitDimsCount() + ::mlir::Value getThreadLimitDimensionValue(unsigned index) { + assert(index < getThreadLimitDimsCount() && + "Thread limit dims index out of bounds"); + return getThreadLimitDimsValues()[index]; + } + }]; } def OpenMP_ThreadLimitClause : OpenMP_ThreadLimitClauseSkip<>; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index e5b98024dbed1..e83419492d28e 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2210,10 +2210,30 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.threadLimit, + clauses.privateNeedsBarrier, clauses.threadLimitNumDims, + clauses.threadLimitDimsValues, clauses.threadLimit, /*private_maps=*/nullptr); } +// helper for thread_limit clause restrictions +static LogicalResult +verifyThreadLimitClause(Operation *op, + std::optional<IntegerAttr> threadLimitNumDims, + OperandRange threadLimitDimsValues, Value threadLimit) { + bool hasDimsModifier = + threadLimitNumDims.has_value() && threadLimitNumDims.value(); + + if (hasDimsModifier && threadLimit) { + return op->emitError("thread_limit with dims modifier cannot be used " + "together with number of threads"); + } + + if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues))) + return failure(); + + return success(); +} + LogicalResult TargetOp::verify() { if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars()))) return failure(); @@ -2225,6 +2245,11 @@ LogicalResult TargetOp::verify() { if (failed(verifyMapClause(*this, getMapVars()))) return failure(); + if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(), + getThreadLimitDimsValues(), + getThreadLimit()))) + return failure(); + return verifyPrivateVarsMapping(*this); } @@ -2687,6 +2712,12 @@ LogicalResult TeamsOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); + // Check for thread_limit clause restrictions + if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(), + getThreadLimitDimsValues(), + getThreadLimit()))) + return failure(); + return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), getReductionByref()); } @@ -4595,6 +4626,42 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, } } +//===----------------------------------------------------------------------===// +// Parser and printer for thread_limit clause +//===----------------------------------------------------------------------===// +static ParseResult +parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, + SmallVectorImpl<Type> &types, + std::optional<OpAsmParser::UnresolvedOperand> &bounds, + Type &boundsType) { + if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { + return success(); + } + + OpAsmParser::UnresolvedOperand boundsOperand; + if (parser.parseOperand(boundsOperand) || parser.parseColon() || + parser.parseType(boundsType)) { + return failure(); + } + bounds = boundsOperand; + return success(); +} + +static void printThreadLimitClause(OpAsmPrinter &p, Operation *op, + IntegerAttr dimsAttr, OperandRange values, + TypeRange types, Value bounds, + Type boundsType) { + if (!values.empty()) { + // Multidimensional: dims(N): values : type + printDimsModifierWithValues(p, dimsAttr, values, types); + } else if (bounds) { + // Both bounds: bounds : type + p.printOperand(bounds); + p << " : " << boundsType; + } +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index bb882db73cbab..e841e65d36292 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) { // expected-error @below {{expected equal sizes for allocate and allocator variables}} "omp.teams" (%data_var) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> () + }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0,0>} : (memref<i32>) -> () omp.terminator } return @@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) { // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}} "omp.teams" (%lb) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0,0>} : (i32) -> () omp.terminator } return @@ -1489,6 +1489,139 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) { // ----- +func.func @omp_teams_thread_limit_dims_mismatch() { + omp.target { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i32 + // expected-error @below {{dims(3) specified but 2 values provided}} + "omp.teams" (%v0, %v1) ({ + omp.terminator + }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> () + omp.terminator + } + return +} + +// ----- + +func.func @omp_teams_thread_limit_dims_with_scalar() { + omp.target { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i32 + %tl = arith.constant 4 : i32 + // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}} + "omp.teams" (%v0, %v1, %tl) ({ + omp.terminator + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> () + omp.terminator + } + return +} + +// ----- + +func.func @omp_teams_thread_limit_dims_no_values() { + omp.target { + // expected-error @below {{dims modifier requires values to be specified}} + "omp.teams" () ({ + omp.terminator + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> () + omp.terminator + } + return +} + +// ----- + +func.func @omp_teams_thread_limit_values_without_dims() { + omp.target { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i32 + // expected-error @below {{dims values can only be specified with dims modifier}} + "omp.teams" (%v0, %v1) ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> () + omp.terminator + } + return +} + +// ----- + +func.func @omp_teams_thread_limit_dims_type_mismatch() { + omp.target { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i64 + // expected-error @below {{dims modifier requires all values to have the same type}} + "omp.teams" (%v0, %v1) ({ + omp.terminator + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> () + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_thread_limit_dims_mismatch() { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i32 + // expected-error @below {{dims(3) specified but 2 values provided}} + "omp.target" (%v0, %v1) ({ + omp.terminator + }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> () + return +} + +// ----- + +func.func @omp_target_thread_limit_dims_with_scalar() { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i32 + %tl = arith.constant 4 : i32 + // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}} + "omp.target" (%v0, %v1, %tl) ({ + omp.terminator + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> () + return +} + +// ----- + +func.func @omp_target_thread_limit_dims_no_values() { + // expected-error @below {{dims modifier requires values to be specified}} + "omp.target" () ({ + omp.terminator + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0,0>} : () -> () + return +} + +// ----- + +func.func @omp_target_thread_limit_values_without_dims() { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i32 + // expected-error @below {{dims values can only be specified with dims modifier}} + "omp.target" (%v0, %v1) ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> () + return +} + +// ----- + +func.func @omp_target_thread_limit_dims_type_mismatch() { + %v0 = arith.constant 1 : i32 + %v1 = arith.constant 2 : i64 + // expected-error @below {{dims modifier requires all values to have the same type}} + "omp.target" (%v0, %v1) ({ + omp.terminator + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> () + return +} + +// ----- + func.func @omp_sections(%data_var : memref<i32>) -> () { // expected-error @below {{expected equal sizes for allocate and allocator variables}} "omp.sections" (%data_var) ({ @@ -2475,7 +2608,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) { // expected-error @below {{op expected as many depend values as depend variables}} "omp.target"(%data_var) ({ "omp.terminator"() : () -> () - }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> () + }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> () "func.return"() : () -> () } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index d28f31c8328b2..39ddc5bfa4e50 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -824,7 +824,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic "omp.target"(%device, %if_cond, %num_threads) ({ // CHECK: omp.terminator omp.terminator - }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> () + }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,0,1>} : ( si32, i1, i32 ) -> () // Test with optional map clause. // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""} @@ -1136,6 +1136,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, omp.terminator } + // CHECK: omp.teams thread_limit(dims(2): %{{.*}}, %{{.*}} : i32) + omp.teams thread_limit(dims(2): %lb, %ub : i32) { + // CHECK: omp.terminator + omp.terminator + } + // Test reduction. %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr >From 1f1eac98b3a73d389a805ffe077a10fb08943599 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 18:35:55 +0530 Subject: [PATCH 3/9] update thread_limit description --- mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index d2ebd29229e84..c8d8d0003deef 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1465,7 +1465,18 @@ class OpenMP_ThreadLimitClauseSkip< }]; let description = [{ - The optional `thread_limit` specifies the limit on the number of threads. + The `thread_limit` clause specifies the limit on the number of threads. + + With dims modifier: + - The number of dimensions is specified by the `thread_limit_num_dims` attribute. + - The values for each dimension are specified by the `thread_limit_dims_values` attribute. + - Format: `thread_limit(dims(N): values : type)` + - Example: `thread_limit(dims(2): %n, %m : i64)` + + Without dims modifier: + - The number of threads is specified by the `thread_limit`. + - Format: `thread_limit(number_of_threads : type)` + - Example: `thread_limit(%n : i64)` }]; let extraClassDeclaration = [{ >From 12da9b0afc48b841481ec35fd2de72ffbb451459 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Wed, 17 Dec 2025 09:06:36 +0530 Subject: [PATCH 4/9] Remove separate thread_limit argument from clause --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 3 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 16 +++--- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 8 +-- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 16 +++--- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 54 +++++++------------ .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 11 ++-- mlir/test/Dialect/OpenMP/invalid.mlir | 28 +++++----- mlir/test/Dialect/OpenMP/ops.mlir | 2 +- 8 files changed, 63 insertions(+), 75 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index b923e415231d6..18bab01d94365 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -668,8 +668,9 @@ bool ClauseProcessor::processThreadLimit( lower::StatementContext &stmtCtx, mlir::omp::ThreadLimitClauseOps &result) const { if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) { - result.threadLimit = + mlir::Value threadLimitVal = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.threadLimitDimsValues.push_back(threadLimitVal); return true; } return false; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 989e370870f33..1021742b87b2f 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -102,8 +102,9 @@ class HostEvalInfo { if (ops.numThreads) vars.push_back(ops.numThreads); - if (ops.threadLimit) - vars.push_back(ops.threadLimit); + // Old spec: single value in threadLimitDimsValues + for (mlir::Value val : ops.threadLimitDimsValues) + vars.push_back(val); } /// Update \c ops, replacing all values with the corresponding block argument @@ -116,7 +117,7 @@ class HostEvalInfo { ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + - (ops.threadLimit ? 1 : 0) && + ops.threadLimitDimsValues.size() && "invalid block argument list"); int argIndex = 0; for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i) @@ -137,8 +138,8 @@ class HostEvalInfo { if (ops.numThreads) ops.numThreads = args[argIndex++]; - if (ops.threadLimit) - ops.threadLimit = args[argIndex++]; + for (size_t i = 0; i < ops.threadLimitDimsValues.size(); ++i) + ops.threadLimitDimsValues[i] = args[argIndex++]; } /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated @@ -185,12 +186,13 @@ class HostEvalInfo { /// \returns whether an update was performed. If not, these clauses were not /// evaluated in the host device. bool apply(mlir::omp::TeamsOperands &clauseOps) { - if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit) + if (!ops.numTeamsLower && !ops.numTeamsUpper && + ops.threadLimitDimsValues.empty()) return false; clauseOps.numTeamsLower = ops.numTeamsLower; clauseOps.numTeamsUpper = ops.numTeamsUpper; - clauseOps.threadLimit = ops.threadLimit; + clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues; return true; } diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index a3b9e5c76bdd2..4d3fec3b0710f 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -767,7 +767,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), - targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + targetOp.getPrivateMapsAttr()); rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), newTargetOp.getRegion().begin()); rewriter.replaceOp(targetOp, targetDataOp); @@ -1488,7 +1488,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), - targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + targetOp.getPrivateMapsAttr()); auto *preTargetBlock = rewriter.createBlock( &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); IRMapping preMapping; @@ -1579,7 +1579,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), - targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + targetOp.getPrivateMapsAttr()); auto *isolatedTargetBlock = rewriter.createBlock(&isolatedTargetOp.getRegion(), isolatedTargetOp.getRegion().begin(), {}, {}); @@ -1660,7 +1660,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), - targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + targetOp.getPrivateMapsAttr()); // Create the block for postTargetOp auto *postTargetBlock = rewriter.createBlock( &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index c8d8d0003deef..de39a94c17a6e 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1453,14 +1453,12 @@ class OpenMP_ThreadLimitClauseSkip< extraClassDeclaration> { let arguments = (ins ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims, - Variadic<AnyInteger>:$thread_limit_dims_values, - Optional<AnyInteger>:$thread_limit + Variadic<AnyInteger>:$thread_limit_dims_values ); let optAssemblyFormat = [{ `thread_limit` `(` custom<ThreadLimitClause>( - $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values), - $thread_limit, type($thread_limit) + $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values) ) `)` }]; @@ -1468,14 +1466,14 @@ class OpenMP_ThreadLimitClauseSkip< The `thread_limit` clause specifies the limit on the number of threads. With dims modifier: - - The number of dimensions is specified by the `thread_limit_num_dims` attribute. - - The values for each dimension are specified by the `thread_limit_dims_values` attribute. + - The number of dimensions is specified by the `thread_limit_num_dims`. + - The values for each dimension are specified by the `thread_limit_dims_values`. - Format: `thread_limit(dims(N): values : type)` - Example: `thread_limit(dims(2): %n, %m : i64)` Without dims modifier: - - The number of threads is specified by the `thread_limit`. - - Format: `thread_limit(number_of_threads : type)` + - The number of threads is specified by the single value in `thread_limit_dims_values`. + - Format: `thread_limit(value : type)` - Example: `thread_limit(%n : i64)` }]; @@ -1497,6 +1495,8 @@ class OpenMP_ThreadLimitClauseSkip< ::mlir::Value getThreadLimitDimensionValue(unsigned index) { assert(index < getThreadLimitDimsCount() && "Thread limit dims index out of bounds"); + if (getThreadLimitDimsValues().empty()) + return nullptr; return getThreadLimitDimsValues()[index]; } }]; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index e83419492d28e..4f96b1c079670 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2211,7 +2211,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, clauses.mapVars, clauses.nowait, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.threadLimitNumDims, - clauses.threadLimitDimsValues, clauses.threadLimit, + clauses.threadLimitDimsValues, /*private_maps=*/nullptr); } @@ -2219,15 +2219,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, static LogicalResult verifyThreadLimitClause(Operation *op, std::optional<IntegerAttr> threadLimitNumDims, - OperandRange threadLimitDimsValues, Value threadLimit) { - bool hasDimsModifier = - threadLimitNumDims.has_value() && threadLimitNumDims.value(); - - if (hasDimsModifier && threadLimit) { - return op->emitError("thread_limit with dims modifier cannot be used " - "together with number of threads"); - } - + OperandRange threadLimitDimsValues) { if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues))) return failure(); @@ -2246,8 +2238,7 @@ LogicalResult TargetOp::verify() { return failure(); if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(), - getThreadLimitDimsValues(), - getThreadLimit()))) + getThreadLimitDimsValues()))) return failure(); return verifyPrivateVarsMapping(*this); @@ -2265,10 +2256,9 @@ LogicalResult TargetOp::verifyRegions() { cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) { for (Operation *user : hostEvalArg.getUsers()) { if (auto teamsOp = dyn_cast<TeamsOp>(user)) { - if (llvm::is_contained({teamsOp.getNumTeamsLower(), - teamsOp.getNumTeamsUpper(), - teamsOp.getThreadLimit()}, - hostEvalArg)) + if (teamsOp.getNumTeamsLower() == hostEvalArg || + teamsOp.getNumTeamsUpper() == hostEvalArg || + llvm::is_contained(teamsOp.getThreadLimitDimsValues(), hostEvalArg)) continue; return emitOpError() << "host_eval argument only legal as 'num_teams' " @@ -2714,8 +2704,7 @@ LogicalResult TeamsOp::verify() { // Check for thread_limit clause restrictions if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(), - getThreadLimitDimsValues(), - getThreadLimit()))) + getThreadLimitDimsValues()))) return failure(); return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), @@ -4632,34 +4621,29 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, static ParseResult parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, - SmallVectorImpl<Type> &types, - std::optional<OpAsmParser::UnresolvedOperand> &bounds, - Type &boundsType) { + SmallVectorImpl<Type> &types) { + // Try parsing with dims modifier: dims(N): values : type if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { return success(); } - OpAsmParser::UnresolvedOperand boundsOperand; - if (parser.parseOperand(boundsOperand) || parser.parseColon() || - parser.parseType(boundsType)) { + // Without dims modifier: value : type + OpAsmParser::UnresolvedOperand singleValue; + Type singleType; + if (parser.parseOperand(singleValue) || parser.parseColon() || + parser.parseType(singleType)) { return failure(); } - bounds = boundsOperand; + values.push_back(singleValue); + types.push_back(singleType); return success(); } static void printThreadLimitClause(OpAsmPrinter &p, Operation *op, IntegerAttr dimsAttr, OperandRange values, - TypeRange types, Value bounds, - Type boundsType) { - if (!values.empty()) { - // Multidimensional: dims(N): values : type - printDimsModifierWithValues(p, dimsAttr, values, types); - } else if (bounds) { - // Both bounds: bounds : type - p.printOperand(bounds); - p << " : " << boundsType; - } + TypeRange types) { + // Multidimensional: dims(N): values : type + printDimsModifierWithValues(p, dimsAttr, values, types); } #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8a3a990e5a3fd..8cb6d4b21b8b2 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2075,7 +2075,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar); llvm::Value *threadLimit = nullptr; - if (Value threadLimitVar = op.getThreadLimit()) + if (Value threadLimitVar = op.getThreadLimitDimensionValue(0)) threadLimit = moduleTranslation.lookupValue(threadLimitVar); llvm::Value *ifExpr = nullptr; @@ -6044,7 +6044,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, numTeamsLower = hostEvalVar; else if (teamsOp.getNumTeamsUpper() == blockArg) numTeamsUpper = hostEvalVar; - else if (teamsOp.getThreadLimit() == blockArg) + else if (teamsOp.getThreadLimitDimensionValue(0) == blockArg) threadLimit = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6164,7 +6164,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) { numTeamsLower = teamsOp.getNumTeamsLower(); numTeamsUpper = teamsOp.getNumTeamsUpper(); - threadLimit = teamsOp.getThreadLimit(); + threadLimit = teamsOp.getThreadLimitDimensionValue(0); } if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) @@ -6209,7 +6209,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, // Extract 'thread_limit' clause from 'target' and 'teams' directives. int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1; - setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal); + setMaxValueFromClause(targetOp.getThreadLimitDimensionValue(0), + targetThreadLimitVal); setMaxValueFromClause(threadLimit, teamsThreadLimitVal); // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD. @@ -6288,7 +6289,7 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, teamsThreadLimit, &lowerBounds, &upperBounds, &steps); // TODO: Handle constant 'if' clauses. - if (Value targetThreadLimit = targetOp.getThreadLimit()) + if (Value targetThreadLimit = targetOp.getThreadLimitDimensionValue(0)) attrs.TargetThreadLimit.front() = moduleTranslation.lookupValue(targetThreadLimit); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index e841e65d36292..649c0fde35ee6 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) { // expected-error @below {{expected equal sizes for allocate and allocator variables}} "omp.teams" (%data_var) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0,0>} : (memref<i32>) -> () + }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> () omp.terminator } return @@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) { // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}} "omp.teams" (%lb) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0,0>} : (i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> () omp.terminator } return @@ -1496,7 +1496,7 @@ func.func @omp_teams_thread_limit_dims_mismatch() { // expected-error @below {{dims(3) specified but 2 values provided}} "omp.teams" (%v0, %v1) ({ omp.terminator - }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> () + }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> () omp.terminator } return @@ -1509,10 +1509,10 @@ func.func @omp_teams_thread_limit_dims_with_scalar() { %v0 = arith.constant 1 : i32 %v1 = arith.constant 2 : i32 %tl = arith.constant 4 : i32 - // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}} + // expected-error @below {{dims(2) specified but 3 values provided}} "omp.teams" (%v0, %v1, %tl) ({ omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> () + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> () omp.terminator } return @@ -1540,7 +1540,7 @@ func.func @omp_teams_thread_limit_values_without_dims() { // expected-error @below {{dims values can only be specified with dims modifier}} "omp.teams" (%v0, %v1) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> () omp.terminator } return @@ -1555,7 +1555,7 @@ func.func @omp_teams_thread_limit_dims_type_mismatch() { // expected-error @below {{dims modifier requires all values to have the same type}} "omp.teams" (%v0, %v1) ({ omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> () + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i64) -> () omp.terminator } return @@ -1569,7 +1569,7 @@ func.func @omp_target_thread_limit_dims_mismatch() { // expected-error @below {{dims(3) specified but 2 values provided}} "omp.target" (%v0, %v1) ({ omp.terminator - }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> () + }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> () return } @@ -1579,10 +1579,10 @@ func.func @omp_target_thread_limit_dims_with_scalar() { %v0 = arith.constant 1 : i32 %v1 = arith.constant 2 : i32 %tl = arith.constant 4 : i32 - // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}} + // expected-error @below {{dims(2) specified but 3 values provided}} "omp.target" (%v0, %v1, %tl) ({ omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> () + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> () return } @@ -1592,7 +1592,7 @@ func.func @omp_target_thread_limit_dims_no_values() { // expected-error @below {{dims modifier requires values to be specified}} "omp.target" () ({ omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0,0>} : () -> () + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0>} : () -> () return } @@ -1604,7 +1604,7 @@ func.func @omp_target_thread_limit_values_without_dims() { // expected-error @below {{dims values can only be specified with dims modifier}} "omp.target" (%v0, %v1) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> () return } @@ -1616,7 +1616,7 @@ func.func @omp_target_thread_limit_dims_type_mismatch() { // expected-error @below {{dims modifier requires all values to have the same type}} "omp.target" (%v0, %v1) ({ omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> () + }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i64) -> () return } @@ -2608,7 +2608,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) { // expected-error @below {{op expected as many depend values as depend variables}} "omp.target"(%data_var) ({ "omp.terminator"() : () -> () - }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> () + }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> () "func.return"() : () -> () } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 39ddc5bfa4e50..51b1eed766ac3 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -824,7 +824,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic "omp.target"(%device, %if_cond, %num_threads) ({ // CHECK: omp.terminator omp.terminator - }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,0,1>} : ( si32, i1, i32 ) -> () + }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> () // Test with optional map clause. // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""} >From 243d8dc86b3602cacc18e754fb41a935f502cdd4 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 19 Dec 2025 10:11:01 +0530 Subject: [PATCH 5/9] comments fixes --- mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index de39a94c17a6e..7d2810018e45f 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1453,7 +1453,7 @@ class OpenMP_ThreadLimitClauseSkip< extraClassDeclaration> { let arguments = (ins ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims, - Variadic<AnyInteger>:$thread_limit_dims_values + Variadic<IntLikeType>:$thread_limit_dims_values ); let optAssemblyFormat = [{ >From 7876b6b7015f09102b57970a6d04a909de953dea Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 19 Dec 2025 15:07:17 +0530 Subject: [PATCH 6/9] fix comment --- .../Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8cb6d4b21b8b2..7b2426e860d8d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -6040,6 +6040,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, for (Operation *user : blockArg.getUsers()) { llvm::TypeSwitch<Operation *>(user) .Case([&](omp::TeamsOp teamsOp) { + // num_teams dims and values are not yet supported if (teamsOp.getNumTeamsLower() == blockArg) numTeamsLower = hostEvalVar; else if (teamsOp.getNumTeamsUpper() == blockArg) >From f40deb8a3b3752d28ffcc3e6cb983a9ca2ac66e9 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Wed, 14 Jan 2026 12:19:24 +0530 Subject: [PATCH 7/9] [Flang] Add missing threadLimitNumDims in TeamsOperands apply method --- flang/lib/Lower/OpenMP/OpenMP.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 1021742b87b2f..f05c918ed3cde 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -193,6 +193,7 @@ class HostEvalInfo { clauseOps.numTeamsLower = ops.numTeamsLower; clauseOps.numTeamsUpper = ops.numTeamsUpper; clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues; + clauseOps.threadLimitNumDims = ops.threadLimitNumDims; return true; } >From f72662402e2d1f3d96476b305dea242550c7ea78 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 16 Jan 2026 09:55:34 +0530 Subject: [PATCH 8/9] remove dims(N) syntax and just use list for dims vals --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 2 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 15 +- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 12 +- flang/test/Lower/OpenMP/teams.f90 | 2 +- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 48 +++---- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 56 ++------ .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 24 +++- mlir/test/Dialect/OpenMP/invalid.mlir | 133 ------------------ mlir/test/Dialect/OpenMP/ops.mlir | 15 +- mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 ++ 10 files changed, 87 insertions(+), 231 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 18bab01d94365..8083b1b10aee7 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -670,7 +670,7 @@ bool ClauseProcessor::processThreadLimit( if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) { mlir::Value threadLimitVal = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); - result.threadLimitDimsValues.push_back(threadLimitVal); + result.threadLimitVals.push_back(threadLimitVal); return true; } return false; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index f05c918ed3cde..b8068ee09cf81 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -102,8 +102,7 @@ class HostEvalInfo { if (ops.numThreads) vars.push_back(ops.numThreads); - // Old spec: single value in threadLimitDimsValues - for (mlir::Value val : ops.threadLimitDimsValues) + for (mlir::Value val : ops.threadLimitVals) vars.push_back(val); } @@ -117,7 +116,7 @@ class HostEvalInfo { ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + - ops.threadLimitDimsValues.size() && + ops.threadLimitVals.size() && "invalid block argument list"); int argIndex = 0; for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i) @@ -138,8 +137,8 @@ class HostEvalInfo { if (ops.numThreads) ops.numThreads = args[argIndex++]; - for (size_t i = 0; i < ops.threadLimitDimsValues.size(); ++i) - ops.threadLimitDimsValues[i] = args[argIndex++]; + for (size_t i = 0; i < ops.threadLimitVals.size(); ++i) + ops.threadLimitVals[i] = args[argIndex++]; } /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated @@ -186,14 +185,12 @@ class HostEvalInfo { /// \returns whether an update was performed. If not, these clauses were not /// evaluated in the host device. bool apply(mlir::omp::TeamsOperands &clauseOps) { - if (!ops.numTeamsLower && !ops.numTeamsUpper && - ops.threadLimitDimsValues.empty()) + if (!ops.numTeamsLower && !ops.numTeamsUpper && ops.threadLimitVals.empty()) return false; clauseOps.numTeamsLower = ops.numTeamsLower; clauseOps.numTeamsUpper = ops.numTeamsUpper; - clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues; - clauseOps.threadLimitNumDims = ops.threadLimitNumDims; + clauseOps.threadLimitVals = ops.threadLimitVals; return true; } diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 4d3fec3b0710f..b804a14e32f0c 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -766,8 +766,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), - targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), - targetOp.getPrivateMapsAttr()); + targetOp.getThreadLimitVals(), targetOp.getPrivateMapsAttr()); rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), newTargetOp.getRegion().begin()); rewriter.replaceOp(targetOp, targetDataOp); @@ -1486,8 +1485,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), - targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(), targetOp.getPrivateMapsAttr()); auto *preTargetBlock = rewriter.createBlock( &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); @@ -1577,8 +1575,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), - targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(), targetOp.getPrivateMapsAttr()); auto *isolatedTargetBlock = rewriter.createBlock(&isolatedTargetOp.getRegion(), @@ -1658,8 +1655,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), - targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(), targetOp.getPrivateMapsAttr()); // Create the block for postTargetOp auto *postTargetBlock = rewriter.createBlock( diff --git a/flang/test/Lower/OpenMP/teams.f90 b/flang/test/Lower/OpenMP/teams.f90 index 47d379d6c2842..e5ba7070cf664 100644 --- a/flang/test/Lower/OpenMP/teams.f90 +++ b/flang/test/Lower/OpenMP/teams.f90 @@ -21,7 +21,7 @@ subroutine teams_numteams(num_teams) integer, intent(inout) :: num_teams ! CHECK: omp.teams - ! CHECK-SAME: num_teams( to %{{.*}}: i32) + ! CHECK-SAME: num_teams(to %{{.*}}: i32) !$omp teams num_teams(4) ! CHECK: fir.call call f1() diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 7d2810018e45f..9d4a01e9edf13 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1452,52 +1452,48 @@ class OpenMP_ThreadLimitClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins - ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims, - Variadic<IntLikeType>:$thread_limit_dims_values + Variadic<IntLikeType>:$thread_limit_vals ); let optAssemblyFormat = [{ `thread_limit` `(` custom<ThreadLimitClause>( - $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values) + $thread_limit_vals, type($thread_limit_vals) ) `)` }]; let description = [{ The `thread_limit` clause specifies the limit on the number of threads. - With dims modifier: - - The number of dimensions is specified by the `thread_limit_num_dims`. - - The values for each dimension are specified by the `thread_limit_dims_values`. - - Format: `thread_limit(dims(N): values : type)` - - Example: `thread_limit(dims(2): %n, %m : i64)` + Multi-dimensional format (dims modifier): + - Multiple values can be specified for multi-dimensional thread limits. + - The number of dimensions is derived from the number of values. + - Values can have different integer types. + - Format: `thread_limit(%v1, %v2, ... : type1, type2, ...)` + - Example: `thread_limit(%n, %m : i32, i64)` - Without dims modifier: - - The number of threads is specified by the single value in `thread_limit_dims_values`. - - Format: `thread_limit(value : type)` - - Example: `thread_limit(%n : i64)` + Single value format: + - A single value specifies the thread limit. + - Format: `thread_limit(%value : type)` + - Example: `thread_limit(%n : i32)` }]; let extraClassDeclaration = [{ - /// Returns true if the dims modifier is explicitly present - bool hasThreadLimitDimsModifier() { - return getThreadLimitNumDims().has_value() && getThreadLimitNumDims().value(); + /// Returns true if using multi-dimensional values (more than one value) + bool hasThreadLimitMultiDim() { + return getThreadLimitVals().size() > 1; } - /// Returns the number of dimensions specified by dims modifier + /// Returns the number of dimensions specified for thread_limit unsigned getThreadLimitDimsCount() { - if (!hasThreadLimitDimsModifier()) - return 1; - return static_cast<unsigned>(*getThreadLimitNumDims()); + return getThreadLimitVals().size(); } /// Returns the value for a specific dimension index - /// Index must be less than getThreadLimitDimsCount() - ::mlir::Value getThreadLimitDimensionValue(unsigned index) { - assert(index < getThreadLimitDimsCount() && - "Thread limit dims index out of bounds"); - if (getThreadLimitDimsValues().empty()) - return nullptr; - return getThreadLimitDimsValues()[index]; + /// Index must be less than getThreadLimitVals().size() + ::mlir::Value getThreadLimitVal(unsigned index) { + assert(index < getThreadLimitVals().size() && + "Thread limit index out of bounds"); + return getThreadLimitVals()[index]; } }]; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 4f96b1c079670..c22830f3b08ec 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2210,22 +2210,10 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.threadLimitNumDims, - clauses.threadLimitDimsValues, + clauses.privateNeedsBarrier, clauses.threadLimitVals, /*private_maps=*/nullptr); } -// helper for thread_limit clause restrictions -static LogicalResult -verifyThreadLimitClause(Operation *op, - std::optional<IntegerAttr> threadLimitNumDims, - OperandRange threadLimitDimsValues) { - if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues))) - return failure(); - - return success(); -} - LogicalResult TargetOp::verify() { if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars()))) return failure(); @@ -2237,10 +2225,6 @@ LogicalResult TargetOp::verify() { if (failed(verifyMapClause(*this, getMapVars()))) return failure(); - if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(), - getThreadLimitDimsValues()))) - return failure(); - return verifyPrivateVarsMapping(*this); } @@ -2258,7 +2242,7 @@ LogicalResult TargetOp::verifyRegions() { if (auto teamsOp = dyn_cast<TeamsOp>(user)) { if (teamsOp.getNumTeamsLower() == hostEvalArg || teamsOp.getNumTeamsUpper() == hostEvalArg || - llvm::is_contained(teamsOp.getThreadLimitDimsValues(), hostEvalArg)) + llvm::is_contained(teamsOp.getThreadLimitVals(), hostEvalArg)) continue; return emitOpError() << "host_eval argument only legal as 'num_teams' " @@ -2647,7 +2631,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), makeArrayAttr(ctx, clauses.reductionSyms), - clauses.threadLimit); + clauses.threadLimitVals); } // Verify num_teams clause @@ -2702,11 +2686,6 @@ LogicalResult TeamsOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); - // Check for thread_limit clause restrictions - if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(), - getThreadLimitDimsValues()))) - return failure(); - return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), getReductionByref()); } @@ -4608,7 +4587,7 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, p << " : " << upperBoundType; } else { // Upper only: to upper : type - p << "to "; + p << " to "; p.printOperand(upperBound); p << " : " << upperBoundType; } @@ -4619,31 +4598,24 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, // Parser and printer for thread_limit clause //===----------------------------------------------------------------------===// static ParseResult -parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr, +parseThreadLimitClause(OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, SmallVectorImpl<Type> &types) { - // Try parsing with dims modifier: dims(N): values : type - if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { - return success(); - } - - // Without dims modifier: value : type - OpAsmParser::UnresolvedOperand singleValue; - Type singleType; - if (parser.parseOperand(singleValue) || parser.parseColon() || - parser.parseType(singleType)) { + // Parse comma-separated list of values with their types + // Format: %v1, %v2, ... : type1, type2, ... + if (parser.parseOperandList(values) || parser.parseColon() || + parser.parseTypeList(types)) { return failure(); } - values.push_back(singleValue); - types.push_back(singleType); return success(); } static void printThreadLimitClause(OpAsmPrinter &p, Operation *op, - IntegerAttr dimsAttr, OperandRange values, - TypeRange types) { - // Multidimensional: dims(N): values : type - printDimsModifierWithValues(p, dimsAttr, values, types); + OperandRange values, TypeRange types) { + // Print values with their types + llvm::interleaveComma(values, p, [&](Value v) { p << v; }); + p << " : "; + llvm::interleaveComma(types, p, [&](Type t) { p << t; }); } #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 7b2426e860d8d..7abbeaedc446d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.hasNumTeamsMultiDim()) result = todo("num_teams with multi-dimensional values"); }; + auto checkThreadLimitMultiDim = [&todo](auto op, LogicalResult &result) { + if (op.hasThreadLimitMultiDim()) + result = todo("thread_limit with multi-dimensional values"); + }; LogicalResult result = success(); llvm::TypeSwitch<Operation &>(op) @@ -405,6 +409,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkAllocate(op, result); checkPrivate(op, result); checkNumTeamsMultiDim(op, result); + checkThreadLimitMultiDim(op, result); }) .Case([&](omp::TaskOp op) { checkAllocate(op, result); @@ -442,6 +447,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkAllocate(op, result); checkBare(op, result); checkInReduction(op, result); + checkThreadLimitMultiDim(op, result); }) .Default([](Operation &) { // Assume all clauses for an operation can be translated unless they are @@ -2075,8 +2081,8 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar); llvm::Value *threadLimit = nullptr; - if (Value threadLimitVar = op.getThreadLimitDimensionValue(0)) - threadLimit = moduleTranslation.lookupValue(threadLimitVar); + if (!op.getThreadLimitVals().empty()) + threadLimit = moduleTranslation.lookupValue(op.getThreadLimitVal(0)); llvm::Value *ifExpr = nullptr; if (Value ifVar = op.getIfExpr()) @@ -6045,7 +6051,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, numTeamsLower = hostEvalVar; else if (teamsOp.getNumTeamsUpper() == blockArg) numTeamsUpper = hostEvalVar; - else if (teamsOp.getThreadLimitDimensionValue(0) == blockArg) + else if (!teamsOp.getThreadLimitVals().empty() && + teamsOp.getThreadLimitVal(0) == blockArg) threadLimit = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6165,7 +6172,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) { numTeamsLower = teamsOp.getNumTeamsLower(); numTeamsUpper = teamsOp.getNumTeamsUpper(); - threadLimit = teamsOp.getThreadLimitDimensionValue(0); + if (!teamsOp.getThreadLimitVals().empty()) + threadLimit = teamsOp.getThreadLimitVal(0); } if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) @@ -6210,8 +6218,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, // Extract 'thread_limit' clause from 'target' and 'teams' directives. int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1; - setMaxValueFromClause(targetOp.getThreadLimitDimensionValue(0), - targetThreadLimitVal); + if (!targetOp.getThreadLimitVals().empty()) + setMaxValueFromClause(targetOp.getThreadLimitVal(0), targetThreadLimitVal); setMaxValueFromClause(threadLimit, teamsThreadLimitVal); // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD. @@ -6290,9 +6298,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, teamsThreadLimit, &lowerBounds, &upperBounds, &steps); // TODO: Handle constant 'if' clauses. - if (Value targetThreadLimit = targetOp.getThreadLimitDimensionValue(0)) + if (!targetOp.getThreadLimitVals().empty()) { + Value targetThreadLimit = targetOp.getThreadLimitVal(0); attrs.TargetThreadLimit.front() = moduleTranslation.lookupValue(targetThreadLimit); + } if (numTeamsLower) attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 649c0fde35ee6..bb882db73cbab 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1489,139 +1489,6 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) { // ----- -func.func @omp_teams_thread_limit_dims_mismatch() { - omp.target { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i32 - // expected-error @below {{dims(3) specified but 2 values provided}} - "omp.teams" (%v0, %v1) ({ - omp.terminator - }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> () - omp.terminator - } - return -} - -// ----- - -func.func @omp_teams_thread_limit_dims_with_scalar() { - omp.target { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i32 - %tl = arith.constant 4 : i32 - // expected-error @below {{dims(2) specified but 3 values provided}} - "omp.teams" (%v0, %v1, %tl) ({ - omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> () - omp.terminator - } - return -} - -// ----- - -func.func @omp_teams_thread_limit_dims_no_values() { - omp.target { - // expected-error @below {{dims modifier requires values to be specified}} - "omp.teams" () ({ - omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> () - omp.terminator - } - return -} - -// ----- - -func.func @omp_teams_thread_limit_values_without_dims() { - omp.target { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i32 - // expected-error @below {{dims values can only be specified with dims modifier}} - "omp.teams" (%v0, %v1) ({ - omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> () - omp.terminator - } - return -} - -// ----- - -func.func @omp_teams_thread_limit_dims_type_mismatch() { - omp.target { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i64 - // expected-error @below {{dims modifier requires all values to have the same type}} - "omp.teams" (%v0, %v1) ({ - omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i64) -> () - omp.terminator - } - return -} - -// ----- - -func.func @omp_target_thread_limit_dims_mismatch() { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i32 - // expected-error @below {{dims(3) specified but 2 values provided}} - "omp.target" (%v0, %v1) ({ - omp.terminator - }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> () - return -} - -// ----- - -func.func @omp_target_thread_limit_dims_with_scalar() { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i32 - %tl = arith.constant 4 : i32 - // expected-error @below {{dims(2) specified but 3 values provided}} - "omp.target" (%v0, %v1, %tl) ({ - omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> () - return -} - -// ----- - -func.func @omp_target_thread_limit_dims_no_values() { - // expected-error @below {{dims modifier requires values to be specified}} - "omp.target" () ({ - omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0>} : () -> () - return -} - -// ----- - -func.func @omp_target_thread_limit_values_without_dims() { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i32 - // expected-error @below {{dims values can only be specified with dims modifier}} - "omp.target" (%v0, %v1) ({ - omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> () - return -} - -// ----- - -func.func @omp_target_thread_limit_dims_type_mismatch() { - %v0 = arith.constant 1 : i32 - %v1 = arith.constant 2 : i64 - // expected-error @below {{dims modifier requires all values to have the same type}} - "omp.target" (%v0, %v1) ({ - omp.terminator - }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i64) -> () - return -} - -// ----- - func.func @omp_sections(%data_var : memref<i32>) -> () { // expected-error @below {{expected equal sizes for allocate and allocator variables}} "omp.sections" (%data_var) ({ diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 51b1eed766ac3..4e5acca796584 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1103,7 +1103,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, omp.terminator } - // CHECK: omp.teams num_teams(to %{{.+}} : i32) + // CHECK: omp.teams num_teams( to %{{.+}} : i32) omp.teams num_teams(to %ub : i32) { // CHECK: omp.terminator omp.terminator @@ -1136,8 +1136,15 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, omp.terminator } - // CHECK: omp.teams thread_limit(dims(2): %{{.*}}, %{{.*}} : i32) - omp.teams thread_limit(dims(2): %lb, %ub : i32) { + // CHECK: omp.teams thread_limit(%{{.*}}, %{{.*}} : i32, i32) + omp.teams thread_limit(%lb, %ub : i32, i32) { + // CHECK: omp.terminator + omp.terminator + } + + // Test thread_limit with mixed types. + // CHECK: omp.teams thread_limit(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16) + omp.teams thread_limit(%lb, %ub64, %ub16 : i32, i64, i16) { // CHECK: omp.terminator omp.terminator } @@ -3090,7 +3097,7 @@ func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<? func.func @omp_target_host_eval(%x : i32) { // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) { - // CHECK: omp.teams num_teams(to %[[HOST_ARG]] : i32) + // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32) // CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32) omp.target host_eval(%x -> %arg0 : i32) { omp.teams num_teams( to %arg0 : i32) thread_limit(%arg0 : i32) { diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 3681ce38bd523..c766cc9568b4f 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -452,6 +452,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) { // ----- +llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) { + // expected-error@below {{not yet implemented: Unhandled clause thread_limit with multi-dimensional values in omp.teams operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.teams}} + omp.teams thread_limit(%lb, %ub : i32, i32) { + omp.terminator + } + llvm.return +} + +// ----- + llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}} // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} >From bea50e20b09934bad867e6b847517785dedff923 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Sat, 17 Jan 2026 10:21:49 +0530 Subject: [PATCH 9/9] remove custom parser/printer for dims --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 14 +++++------ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 24 ------------------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 ++++++------- 3 files changed, 14 insertions(+), 40 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 9d4a01e9edf13..1970e2115003f 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1456,9 +1456,7 @@ class OpenMP_ThreadLimitClauseSkip< ); let optAssemblyFormat = [{ - `thread_limit` `(` custom<ThreadLimitClause>( - $thread_limit_vals, type($thread_limit_vals) - ) `)` + `thread_limit` `(` $thread_limit_vals `:` type($thread_limit_vals) `)` }]; let description = [{ @@ -1488,12 +1486,12 @@ class OpenMP_ThreadLimitClauseSkip< return getThreadLimitVals().size(); } - /// Returns the value for a specific dimension index - /// Index must be less than getThreadLimitVals().size() - ::mlir::Value getThreadLimitVal(unsigned index) { - assert(index < getThreadLimitVals().size() && + /// Returns the value for a specific dimension + /// dim must be less than getThreadLimitDimsCount() + ::mlir::Value getThreadLimit(unsigned dim = 0) { + assert(dim < getThreadLimitDimsCount() && "Thread limit index out of bounds"); - return getThreadLimitVals()[index]; + return getThreadLimitVals()[dim]; } }]; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c22830f3b08ec..d5532f959ae1b 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4594,30 +4594,6 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, } } -//===----------------------------------------------------------------------===// -// Parser and printer for thread_limit clause -//===----------------------------------------------------------------------===// -static ParseResult -parseThreadLimitClause(OpAsmParser &parser, - SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, - SmallVectorImpl<Type> &types) { - // Parse comma-separated list of values with their types - // Format: %v1, %v2, ... : type1, type2, ... - if (parser.parseOperandList(values) || parser.parseColon() || - parser.parseTypeList(types)) { - return failure(); - } - return success(); -} - -static void printThreadLimitClause(OpAsmPrinter &p, Operation *op, - OperandRange values, TypeRange types) { - // Print values with their types - llvm::interleaveComma(values, p, [&](Value v) { p << v; }); - p << " : "; - llvm::interleaveComma(types, p, [&](Type t) { p << t; }); -} - #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 7abbeaedc446d..a25654926ca01 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -380,7 +380,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.hasNumTeamsMultiDim()) result = todo("num_teams with multi-dimensional values"); }; - auto checkThreadLimitMultiDim = [&todo](auto op, LogicalResult &result) { + auto checkThreadLimit = [&todo](auto op, LogicalResult &result) { if (op.hasThreadLimitMultiDim()) result = todo("thread_limit with multi-dimensional values"); }; @@ -409,7 +409,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkAllocate(op, result); checkPrivate(op, result); checkNumTeamsMultiDim(op, result); - checkThreadLimitMultiDim(op, result); + checkThreadLimit(op, result); }) .Case([&](omp::TaskOp op) { checkAllocate(op, result); @@ -447,7 +447,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkAllocate(op, result); checkBare(op, result); checkInReduction(op, result); - checkThreadLimitMultiDim(op, result); + checkThreadLimit(op, result); }) .Default([](Operation &) { // Assume all clauses for an operation can be translated unless they are @@ -2082,7 +2082,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, llvm::Value *threadLimit = nullptr; if (!op.getThreadLimitVals().empty()) - threadLimit = moduleTranslation.lookupValue(op.getThreadLimitVal(0)); + threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0)); llvm::Value *ifExpr = nullptr; if (Value ifVar = op.getIfExpr()) @@ -6052,7 +6052,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, else if (teamsOp.getNumTeamsUpper() == blockArg) numTeamsUpper = hostEvalVar; else if (!teamsOp.getThreadLimitVals().empty() && - teamsOp.getThreadLimitVal(0) == blockArg) + teamsOp.getThreadLimit(0) == blockArg) threadLimit = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6173,7 +6173,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, numTeamsLower = teamsOp.getNumTeamsLower(); numTeamsUpper = teamsOp.getNumTeamsUpper(); if (!teamsOp.getThreadLimitVals().empty()) - threadLimit = teamsOp.getThreadLimitVal(0); + threadLimit = teamsOp.getThreadLimit(0); } if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) @@ -6219,7 +6219,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, // Extract 'thread_limit' clause from 'target' and 'teams' directives. int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1; if (!targetOp.getThreadLimitVals().empty()) - setMaxValueFromClause(targetOp.getThreadLimitVal(0), targetThreadLimitVal); + setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal); setMaxValueFromClause(threadLimit, teamsThreadLimitVal); // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD. @@ -6299,7 +6299,7 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, // TODO: Handle constant 'if' clauses. if (!targetOp.getThreadLimitVals().empty()) { - Value targetThreadLimit = targetOp.getThreadLimitVal(0); + Value targetThreadLimit = targetOp.getThreadLimit(0); attrs.TargetThreadLimit.front() = moduleTranslation.lookupValue(targetThreadLimit); } _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
