https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767
>From 6093bdcf18e36ad0ef1b97c6c2cac8b8cd9000c3 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 11:56:58 +0530 Subject: [PATCH 1/7] [OpenMP][MLIR] Add num_threads clause with dims modifier support --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 50 +++++++++++- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 2 + mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 79 +++++++++++++++++-- mlir/test/Dialect/OpenMP/invalid.mlir | 33 +++++++- mlir/test/Dialect/OpenMP/ops.mlir | 15 ++-- 5 files changed, 163 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index d4640f254ed1f..aedfa05da1608 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims, + Variadic<AnyInteger>:$num_threads_values, Optional<IntLikeType>:$num_threads ); let optAssemblyFormat = [{ - `num_threads` `(` $num_threads `:` type($num_threads) `)` + `num_threads` `(` custom<NumThreadsClause>( + $num_threads_dims, $num_threads_values, type($num_threads_values), + $num_threads, type($num_threads) + ) `)` }]; let description = [{ - The optional `num_threads` parameter specifies the number of threads which - should be used to execute the parallel region. + num_threads clause specifies the desired number of threads in the team + space formed by the construct on which it appears. + + With dims modifier: + - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list) + - Specifies upper bounds for each dimension (all must have same type) + - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` + - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` + + Without dims modifier: + - Uses `num_threads` + - If lower bound not specified, it defaults to upper bound value + - Format: `num_threads(bounds : type)` + - Example: `num_threads(%ub : i32)` + }]; + + let extraClassDeclaration = [{ + /// Returns true if the dims modifier is explicitly present + bool hasDimsModifier() { + return getNumThreadsDims().has_value(); + } + + /// Returns the number of dimensions specified by dims modifier + unsigned getNumDimensions() { + if (!hasDimsModifier()) + return 1; + return static_cast<unsigned>(*getNumThreadsDims()); + } + + /// Returns all dimension values as an operand range + ::mlir::OperandRange getDimensionValues() { + return getNumThreadsValues(); + } + + /// Returns the value for a specific dimension index + /// Index must be less than getNumDimensions() + ::mlir::Value getDimensionValue(unsigned index) { + assert(index < getDimensionValues().size() && + "Dimension index out of bounds"); + return getDimensionValues()[index]; + } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 6423d49859c97..0d5333ec2e455 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, + /* num_threads_dims = */ nullptr, + /* num_threads_values = */ llvm::SmallVector<Value>{}, /* num_threads = */ numThreadsVar, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 67ff9023a38da..9664b8f59802c 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2504,6 +2504,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, + /*num_threads_dims=*/nullptr, + /*num_threads_values=*/ValueRange(), /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, @@ -2515,13 +2517,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); - ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreads, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.procBindKind, - clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms)); + ParallelOp::build( + builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues, + clauses.numThreads, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, + clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms)); } template <typename OpType> @@ -2568,13 +2571,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) { } LogicalResult ParallelOp::verify() { + // verify num_threads clause restrictions + auto numThreadsDims = getNumThreadsDims(); + auto numThreadsValues = getNumThreadsValues(); + auto numThreads = getNumThreads(); + + // num_threads with dims modifier + if (numThreadsDims.has_value() && numThreadsValues.empty()) { + return emitError( + "num_threads dims modifier requires values to be specified"); + } + + if (numThreadsDims.has_value() && + numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) { + return emitError("num_threads dims(") + << *numThreadsDims << ") specified but " << numThreadsValues.size() + << " values provided"; + } + + // num_threads dims and number of threads cannot be used together + if (numThreadsDims.has_value() && numThreads) { + return emitError( + "num_threads dims and number of threads cannot be used together"); + } + + // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); + // verify private variables restrictions if (failed(verifyPrivateVarList(*this))) return failure(); + // verify reduction variables restrictions return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), getReductionByref()); } @@ -4595,6 +4625,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, } } +//===----------------------------------------------------------------------===// +// Parser and printer for num_threads clause +//===----------------------------------------------------------------------===// +static ParseResult +parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, + SmallVectorImpl<Type> &types, + std::optional<OpAsmParser::UnresolvedOperand> &bounds, + Type &boundsType) { + if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { + return success(); + } + + OpAsmParser::UnresolvedOperand boundsOperand; + if (parser.parseOperand(boundsOperand) || parser.parseColon() || + parser.parseType(boundsType)) { + return failure(); + } + bounds = boundsOperand; + return success(); +} + +static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, + IntegerAttr dimsAttr, OperandRange values, + TypeRange types, Value bounds, + Type boundsType) { + if (!values.empty()) { + printDimsModifierWithValues(p, dimsAttr, values, types); + } + if (bounds) { + p.printOperand(bounds); + p << " : " << boundsType; + } +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index bb882db73cbab..75431ec475954 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) { // ----- +func.func @num_threads_dims_no_values() { + // expected-error@+1 {{num_threads dims modifier requires values to be specified}} + "omp.parallel"() ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> () + return +} + +// ----- + +func.func @num_threads_dims_mismatch(%n : i64) { + // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}} + omp.parallel num_threads(dims(2): %n : i64) { + omp.terminator + } + + return +} + +// ----- + +func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { + // expected-error@+1 {{num_threads dims and number of threads cannot be used together}} + "omp.parallel"(%n, %n, %m) ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> () + return +} + +// ----- + func.func @nowait_not_allowed(%n : memref<i32>) { // expected-error@+1 {{expected '{' to begin a region}} omp.parallel nowait {} @@ -2708,7 +2739,7 @@ func.func @undefined_privatizer(%arg0: index) { // ----- func.func @undefined_privatizer(%arg0: !llvm.ptr) { // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}} - "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ + "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ ^bb0(%arg2: !llvm.ptr): omp.terminator }) : (!llvm.ptr) -> () diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 89c7e5fd48bd9..3acbe010c28a5 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32) "omp.parallel"(%data_var, %data_var, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () // CHECK: omp.barrier omp.barrier @@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}}) "omp.parallel"(%data_var, %data_var, %if_cond) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () // test without allocate // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> () omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) "omp.parallel" (%data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> () // CHECK: omp.parallel omp.parallel { @@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre omp.terminator } + // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64) + omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) { + omp.terminator + } + // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) { omp.terminator >From 97045e6201626b5f73e5178905a9a2cefa09b9cf Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 12:11:49 +0530 Subject: [PATCH 2/7] Mark mlir->llvmir translation for num_threads with dims as NYI --- .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8a3a990e5a3fd..e66666b526069 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3268,6 +3268,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, if (auto ifVar = opInst.getIfExpr()) ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; + // num_threads dims and values are not yet supported + assert(!opInst.getNumThreadsDims().has_value() && + opInst.getNumThreadsValues().empty() && + "Lowering of num_threads with dims modifier is NYI."); if (auto numThreadsVar = opInst.getNumThreads()) numThreads = moduleTranslation.lookupValue(numThreadsVar); auto pbKind = llvm::omp::OMP_PROC_BIND_default; @@ -6050,6 +6054,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::ParallelOp parallelOp) { + // num_threads dims and values are not yet supported + assert(!parallelOp.getNumThreadsDims().has_value() && + parallelOp.getNumThreadsValues().empty() && + "Lowering of num_threads with dims modifier is NYI."); if (parallelOp.getNumThreads() == blockArg) numThreads = hostEvalVar; else @@ -6167,8 +6175,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, threadLimit = teamsOp.getThreadLimit(); } - if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) + if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { + // num_threads dims and values are not yet supported + assert(!parallelOp.getNumThreadsDims().has_value() && + parallelOp.getNumThreadsValues().empty() && + "Lowering of num_threads with dims modifier is NYI."); numThreads = parallelOp.getNumThreads(); + } } // Handle clauses impacting the number of teams. >From 60288588459e658d9d2d1238569a19f34e932b80 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 17:37:52 +0530 Subject: [PATCH 3/7] few more fixes --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 33 ++++++-------- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 4 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 44 +++++++++---------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 ++-- mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++--- 5 files changed, 45 insertions(+), 55 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index aedfa05da1608..3559002c6473f 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins - ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims, - Variadic<AnyInteger>:$num_threads_values, + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims, + Variadic<AnyInteger>:$num_threads_dims_values, Optional<IntLikeType>:$num_threads ); let optAssemblyFormat = [{ `num_threads` `(` custom<NumThreadsClause>( - $num_threads_dims, $num_threads_values, type($num_threads_values), + $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values), $num_threads, type($num_threads) ) `)` }]; @@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip< space formed by the construct on which it appears. With dims modifier: - - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list) + - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list) - Specifies upper bounds for each dimension (all must have same type) - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` @@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip< let extraClassDeclaration = [{ /// Returns true if the dims modifier is explicitly present - bool hasDimsModifier() { - return getNumThreadsDims().has_value(); + bool hasNumThreadsDimsModifier() { + return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value(); } /// Returns the number of dimensions specified by dims modifier - unsigned getNumDimensions() { - if (!hasDimsModifier()) + unsigned getNumThreadsDimsCount() { + if (!hasNumThreadsDimsModifier()) return 1; - return static_cast<unsigned>(*getNumThreadsDims()); - } - - /// Returns all dimension values as an operand range - ::mlir::OperandRange getDimensionValues() { - return getNumThreadsValues(); + return static_cast<unsigned>(*getNumThreadsNumDims()); } /// Returns the value for a specific dimension index - /// Index must be less than getNumDimensions() - ::mlir::Value getDimensionValue(unsigned index) { - assert(index < getDimensionValues().size() && - "Dimension index out of bounds"); - return getDimensionValues()[index]; + /// Index must be less than getNumThreadsDimsCount() + ::mlir::Value getNumThreadsDimsValue(unsigned index) { + assert(index < getNumThreadsDimsCount() && + "Num threads dims index out of bounds"); + return getNumThreadsDimsValues()[index]; } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 0d5333ec2e455..ab7bded7835be 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -448,8 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, - /* num_threads_dims = */ nullptr, - /* num_threads_values = */ llvm::SmallVector<Value>{}, + /* num_threads_num_dims = */ nullptr, + /* num_threads_dims_values = */ llvm::SmallVector<Value>{}, /* num_threads = */ numThreadsVar, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 9664b8f59802c..54ce42f684581 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2519,7 +2519,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); ParallelOp::build( builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues, + clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues, clauses.numThreads, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, @@ -2570,30 +2570,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) { return success(); } -LogicalResult ParallelOp::verify() { - // verify num_threads clause restrictions - auto numThreadsDims = getNumThreadsDims(); - auto numThreadsValues = getNumThreadsValues(); - auto numThreads = getNumThreads(); - - // num_threads with dims modifier - if (numThreadsDims.has_value() && numThreadsValues.empty()) { - return emitError( - "num_threads dims modifier requires values to be specified"); - } - - if (numThreadsDims.has_value() && - numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) { - return emitError("num_threads dims(") - << *numThreadsDims << ") specified but " << numThreadsValues.size() - << " values provided"; +// Helper: Verify num_threads clause +LogicalResult +verifyNumThreadsClause(Operation *op, + std::optional<IntegerAttr> numThreadsNumDims, + OperandRange numThreadsDimsValues, Value numThreads) { + bool hasDimsModifier = + numThreadsNumDims.has_value() && numThreadsNumDims.value(); + if (hasDimsModifier && numThreads) { + return op->emitError("num_threads with dims modifier cannot be used " + "together with number of threads"); } + if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues))) + return failure(); + return success(); +} - // num_threads dims and number of threads cannot be used together - if (numThreadsDims.has_value() && numThreads) { - return emitError( - "num_threads dims and number of threads cannot be used together"); - } +LogicalResult ParallelOp::verify() { + // verify num_threads clause restrictions + if (failed(verifyNumThreadsClause( + getOperation(), this->getNumThreadsNumDimsAttr(), + this->getNumThreadsDimsValues(), this->getNumThreads()))) + return failure(); // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index e66666b526069..67f30383bb03a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3269,8 +3269,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; // num_threads dims and values are not yet supported - assert(!opInst.getNumThreadsDims().has_value() && - opInst.getNumThreadsValues().empty() && + assert(!opInst.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); if (auto numThreadsVar = opInst.getNumThreads()) numThreads = moduleTranslation.lookupValue(numThreadsVar); @@ -6055,8 +6054,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, }) .Case([&](omp::ParallelOp parallelOp) { // num_threads dims and values are not yet supported - assert(!parallelOp.getNumThreadsDims().has_value() && - parallelOp.getNumThreadsValues().empty() && + assert(!parallelOp.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); if (parallelOp.getNumThreads() == blockArg) numThreads = hostEvalVar; @@ -6177,8 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { // num_threads dims and values are not yet supported - assert(!parallelOp.getNumThreadsDims().has_value() && - parallelOp.getNumThreadsValues().empty() && + assert(!parallelOp.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); numThreads = parallelOp.getNumThreads(); } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 75431ec475954..1c5ef785a17f9 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) { // ----- func.func @num_threads_dims_no_values() { - // expected-error@+1 {{num_threads dims modifier requires values to be specified}} + // expected-error@+1 {{dims modifier requires values to be specified}} "omp.parallel"() ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () return } // ----- func.func @num_threads_dims_mismatch(%n : i64) { - // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}} + // expected-error@+1 {{dims(2) specified but 1 values provided}} omp.parallel num_threads(dims(2): %n : i64) { omp.terminator } @@ -52,10 +52,10 @@ func.func @num_threads_dims_mismatch(%n : i64) { // ----- func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { - // expected-error@+1 {{num_threads dims and number of threads cannot be used together}} + // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}} "omp.parallel"(%n, %n, %m) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> () return } >From f07a41aa54d4f27ced44bba8b013e12b4f5ba1dd Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 19 Dec 2025 12:27:38 +0530 Subject: [PATCH 4/7] Use num_threads_dims_values only --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 4 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 15 ++--- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 15 +++-- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 5 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 62 ++++++++----------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 ++--- mlir/test/Dialect/OpenMP/invalid.mlir | 12 ++-- mlir/test/Dialect/OpenMP/ops.mlir | 10 +-- 8 files changed, 66 insertions(+), 73 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index b923e415231d6..abaeaa90f80be 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -516,8 +516,8 @@ bool ClauseProcessor::processNumThreads( mlir::omp::NumThreadsClauseOps &result) const { if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) { // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. - result.numThreads = - fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.numThreadsDimsValues.push_back( + fir::getBase(converter.genExprValue(clause->v, stmtCtx))); return true; } return false; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 989e370870f33..bdbabc292349a 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -99,8 +99,8 @@ class HostEvalInfo { if (ops.numTeamsUpper) vars.push_back(ops.numTeamsUpper); - if (ops.numThreads) - vars.push_back(ops.numThreads); + for (auto numThreads : ops.numThreadsDimsValues) + vars.push_back(numThreads); if (ops.threadLimit) vars.push_back(ops.threadLimit); @@ -115,7 +115,8 @@ class HostEvalInfo { assert(args.size() == ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + - (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + + (ops.numTeamsUpper ? 1 : 0) + + ops.numThreadsDimsValues.size() + (ops.threadLimit ? 1 : 0) && "invalid block argument list"); int argIndex = 0; @@ -134,8 +135,8 @@ class HostEvalInfo { if (ops.numTeamsUpper) ops.numTeamsUpper = args[argIndex++]; - if (ops.numThreads) - ops.numThreads = args[argIndex++]; + for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i) + ops.numThreadsDimsValues[i] = args[argIndex++]; if (ops.threadLimit) ops.threadLimit = args[argIndex++]; @@ -169,13 +170,13 @@ class HostEvalInfo { /// \returns whether an update was performed. If not, these clauses were not /// evaluated in the host device. bool apply(mlir::omp::ParallelOperands &clauseOps) { - if (!ops.numThreads || parallelApplied) { + if (ops.numThreadsDimsValues.empty() || parallelApplied) { parallelApplied = true; return false; } parallelApplied = true; - clauseOps.numThreads = ops.numThreads; + clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues; return true; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 3559002c6473f..8be7030599cc6 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1070,14 +1070,12 @@ class OpenMP_NumThreadsClauseSkip< extraClassDeclaration> { let arguments = (ins ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims, - Variadic<AnyInteger>:$num_threads_dims_values, - Optional<IntLikeType>:$num_threads + Variadic<IntLikeType>:$num_threads_dims_values ); let optAssemblyFormat = [{ `num_threads` `(` custom<NumThreadsClause>( - $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values), - $num_threads, type($num_threads) + $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values) ) `)` }]; @@ -1092,10 +1090,9 @@ class OpenMP_NumThreadsClauseSkip< - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` Without dims modifier: - - Uses `num_threads` - - If lower bound not specified, it defaults to upper bound value - - Format: `num_threads(bounds : type)` - - Example: `num_threads(%ub : i32)` + - The number of threads is specified by single value in `num_threads_dims_values` + - Format: `num_threads(value : type)` + - Example: `num_threads(%n : i32)` }]; let extraClassDeclaration = [{ @@ -1116,6 +1113,8 @@ class OpenMP_NumThreadsClauseSkip< ::mlir::Value getNumThreadsDimsValue(unsigned index) { assert(index < getNumThreadsDimsCount() && "Num threads dims index out of bounds"); + if(getNumThreadsDimsValues().empty()) + return nullptr; return getNumThreadsDimsValues()[index]; } }]; diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index ab7bded7835be..5d75613f9b2b6 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -438,9 +438,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { rewriter.eraseOp(reduce); Value numThreadsVar; + SmallVector<Value> numThreadsValues; if (numThreads > 0) { numThreadsVar = LLVM::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(numThreads)); + numThreadsValues.push_back(numThreadsVar); } // Create the parallel wrapper. auto ompParallel = omp::ParallelOp::create( @@ -449,8 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, /* num_threads_num_dims = */ nullptr, - /* num_threads_dims_values = */ llvm::SmallVector<Value>{}, - /* num_threads = */ numThreadsVar, + /* num_threads_dims_values = */ numThreadsValues, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, /* private_needs_barrier = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 54ce42f684581..6911272d43f6e 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2252,7 +2252,8 @@ LogicalResult TargetOp::verifyRegions() { if (auto parallelOp = dyn_cast<ParallelOp>(user)) { if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && parallelOp->isAncestor(capturedOp) && - hostEvalArg == parallelOp.getNumThreads()) + llvm::is_contained(parallelOp.getNumThreadsDimsValues(), + hostEvalArg)) continue; return emitOpError() @@ -2506,7 +2507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, /*num_threads_dims=*/nullptr, /*num_threads_values=*/ValueRange(), - /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), + /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(), @@ -2517,14 +2518,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); - ParallelOp::build( - builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues, - clauses.numThreads, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, - clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms)); + ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.ifExpr, clauses.numThreadsNumDims, + clauses.numThreadsDimsValues, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.procBindKind, + clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms)); } template <typename OpType> @@ -2574,13 +2575,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) { LogicalResult verifyNumThreadsClause(Operation *op, std::optional<IntegerAttr> numThreadsNumDims, - OperandRange numThreadsDimsValues, Value numThreads) { - bool hasDimsModifier = - numThreadsNumDims.has_value() && numThreadsNumDims.value(); - if (hasDimsModifier && numThreads) { - return op->emitError("num_threads with dims modifier cannot be used " - "together with number of threads"); - } + OperandRange numThreadsDimsValues) { if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues))) return failure(); return success(); @@ -2588,9 +2583,9 @@ verifyNumThreadsClause(Operation *op, LogicalResult ParallelOp::verify() { // verify num_threads clause restrictions - if (failed(verifyNumThreadsClause( - getOperation(), this->getNumThreadsNumDimsAttr(), - this->getNumThreadsDimsValues(), this->getNumThreads()))) + if (failed(verifyNumThreadsClause(getOperation(), + this->getNumThreadsNumDimsAttr(), + this->getNumThreadsDimsValues()))) return failure(); // verify allocate clause restrictions @@ -4629,33 +4624,28 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, static ParseResult parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, - SmallVectorImpl<Type> &types, - std::optional<OpAsmParser::UnresolvedOperand> &bounds, - Type &boundsType) { + SmallVectorImpl<Type> &types) { if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { return success(); } - OpAsmParser::UnresolvedOperand boundsOperand; - if (parser.parseOperand(boundsOperand) || parser.parseColon() || - parser.parseType(boundsType)) { + // Without dims modifier: value : type + OpAsmParser::UnresolvedOperand singleValue; + Type singleType; + if (parser.parseOperand(singleValue) || parser.parseColon() || + parser.parseType(singleType)) { return failure(); } - bounds = boundsOperand; + values.push_back(singleValue); + types.push_back(singleType); return success(); } static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, IntegerAttr dimsAttr, OperandRange values, - TypeRange types, Value bounds, - Type boundsType) { - if (!values.empty()) { - printDimsModifierWithValues(p, dimsAttr, values, types); - } - if (bounds) { - p.printOperand(bounds); - p << " : " << boundsType; - } + TypeRange types) { + // Multidimensional: dims(N): values : type + printDimsModifierWithValues(p, dimsAttr, values, types); } #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 67f30383bb03a..da44dda0a1230 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3270,8 +3270,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, llvm::Value *numThreads = nullptr; // num_threads dims and values are not yet supported assert(!opInst.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is NYI."); - if (auto numThreadsVar = opInst.getNumThreads()) + "Lowering of num_threads with dims modifier is not yet implemented."); + if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0)) numThreads = moduleTranslation.lookupValue(numThreadsVar); auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) @@ -6055,8 +6055,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, .Case([&](omp::ParallelOp parallelOp) { // num_threads dims and values are not yet supported assert(!parallelOp.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is NYI."); - if (parallelOp.getNumThreads() == blockArg) + "Lowering of num_threads with dims modifier is not yet " + "implemented."); + if (parallelOp.getNumThreadsDimsValue(0) == blockArg) numThreads = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6175,9 +6176,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { // num_threads dims and values are not yet supported - assert(!parallelOp.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is NYI."); - numThreads = parallelOp.getNumThreads(); + assert( + !parallelOp.hasNumThreadsDimsModifier() && + "Lowering of num_threads with dims modifier is not yet implemented."); + numThreads = parallelOp.getNumThreadsDimsValue(0); } } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 1c5ef785a17f9..8a5e64b1a98ca 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -34,7 +34,7 @@ func.func @num_threads_dims_no_values() { // expected-error@+1 {{dims modifier requires values to be specified}} "omp.parallel"() ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () return } @@ -51,11 +51,11 @@ func.func @num_threads_dims_mismatch(%n : i64) { // ----- -func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { - // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}} - "omp.parallel"(%n, %n, %m) ({ +func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) { + // expected-error@+1 {{dims values can only be specified with dims modifier}} + "omp.parallel"(%n, %m) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> () return } @@ -2739,7 +2739,7 @@ func.func @undefined_privatizer(%arg0: index) { // ----- func.func @undefined_privatizer(%arg0: !llvm.ptr) { // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}} - "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ + "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ ^bb0(%arg2: !llvm.ptr): omp.terminator }) : (!llvm.ptr) -> () diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 3acbe010c28a5..4c57b8aea0b48 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32) "omp.parallel"(%data_var, %data_var, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () // CHECK: omp.barrier omp.barrier @@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}}) "omp.parallel"(%data_var, %data_var, %if_cond) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () // test without allocate // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> () omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) "omp.parallel" (%data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> () // CHECK: omp.parallel omp.parallel { >From 038f9f4b3cfd4664f4df95e141178c6289194ac4 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Wed, 14 Jan 2026 12:07:56 +0530 Subject: [PATCH 5/7] fix adding numThreadsNumDims to ParallelOperands apply method --- flang/lib/Lower/OpenMP/OpenMP.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index bdbabc292349a..5ca228e218c37 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -177,6 +177,7 @@ class HostEvalInfo { parallelApplied = true; clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues; + clauseOps.numThreadsNumDims = ops.numThreadsNumDims; return true; } >From 12c4749a7dc638ea4f22f2e1dd9cf9fd987f5123 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Fri, 16 Jan 2026 12:32:56 +0530 Subject: [PATCH 6/7] Remove dims(N) syntax and use list of vals for num_threads --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 2 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 14 +++-- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 53 +++++++++---------- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 3 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 51 +++++------------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 26 ++++----- mlir/test/Dialect/OpenMP/invalid.mlir | 31 ----------- mlir/test/Dialect/OpenMP/ops.mlir | 11 +++- mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 ++++ 9 files changed, 76 insertions(+), 126 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index abaeaa90f80be..90825a3653016 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -516,7 +516,7 @@ bool ClauseProcessor::processNumThreads( mlir::omp::NumThreadsClauseOps &result) const { if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) { // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. - result.numThreadsDimsValues.push_back( + result.numThreadsVals.push_back( fir::getBase(converter.genExprValue(clause->v, stmtCtx))); return true; } diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 5ca228e218c37..c9271925580cd 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -99,7 +99,7 @@ class HostEvalInfo { if (ops.numTeamsUpper) vars.push_back(ops.numTeamsUpper); - for (auto numThreads : ops.numThreadsDimsValues) + for (auto numThreads : ops.numThreadsVals) vars.push_back(numThreads); if (ops.threadLimit) @@ -115,8 +115,7 @@ class HostEvalInfo { assert(args.size() == ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + - (ops.numTeamsUpper ? 1 : 0) + - ops.numThreadsDimsValues.size() + + (ops.numTeamsUpper ? 1 : 0) + ops.numThreadsVals.size() + (ops.threadLimit ? 1 : 0) && "invalid block argument list"); int argIndex = 0; @@ -135,8 +134,8 @@ class HostEvalInfo { if (ops.numTeamsUpper) ops.numTeamsUpper = args[argIndex++]; - for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i) - ops.numThreadsDimsValues[i] = args[argIndex++]; + for (size_t i = 0; i < ops.numThreadsVals.size(); ++i) + ops.numThreadsVals[i] = args[argIndex++]; if (ops.threadLimit) ops.threadLimit = args[argIndex++]; @@ -170,14 +169,13 @@ class HostEvalInfo { /// \returns whether an update was performed. If not, these clauses were not /// evaluated in the host device. bool apply(mlir::omp::ParallelOperands &clauseOps) { - if (ops.numThreadsDimsValues.empty() || parallelApplied) { + if (ops.numThreadsVals.empty() || parallelApplied) { parallelApplied = true; return false; } parallelApplied = true; - clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues; - clauseOps.numThreadsNumDims = ops.numThreadsNumDims; + clauseOps.numThreadsVals = ops.numThreadsVals; return true; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 8be7030599cc6..90bff92fbc826 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,53 +1069,48 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins - ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims, - Variadic<IntLikeType>:$num_threads_dims_values + Variadic<IntLikeType>:$num_threads_vals ); let optAssemblyFormat = [{ `num_threads` `(` custom<NumThreadsClause>( - $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values) + $num_threads_vals, type($num_threads_vals) ) `)` }]; let description = [{ - num_threads clause specifies the desired number of threads in the team - space formed by the construct on which it appears. - - With dims modifier: - - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list) - - Specifies upper bounds for each dimension (all must have same type) - - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` - - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` - - Without dims modifier: - - The number of threads is specified by single value in `num_threads_dims_values` - - Format: `num_threads(value : type)` + The `num_threads` clause specifies the number of threads. + + Multi-dimensional format (dims modifier): + - Multiple values can be specified for multi-dimensional thread counts. + - The number of dimensions is derived from the number of values. + - Values can have different integer types. + - Format: `num_threads(%v1, %v2, ... : type1, type2, ...)` + - Example: `num_threads(%n, %m : i32, i64)` + + Single value format: + - A single value specifies the number of threads. + - Format: `num_threads(%value : type)` - Example: `num_threads(%n : i32)` }]; let extraClassDeclaration = [{ - /// Returns true if the dims modifier is explicitly present - bool hasNumThreadsDimsModifier() { - return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value(); + /// Returns true if using multi-dimensional values (more than one value) + bool hasNumThreadsMultiDim() { + return getNumThreadsVals().size() > 1; } - /// Returns the number of dimensions specified by dims modifier + /// Returns the number of dimensions specified for num_threads unsigned getNumThreadsDimsCount() { - if (!hasNumThreadsDimsModifier()) - return 1; - return static_cast<unsigned>(*getNumThreadsNumDims()); + return getNumThreadsVals().size(); } /// Returns the value for a specific dimension index - /// Index must be less than getNumThreadsDimsCount() - ::mlir::Value getNumThreadsDimsValue(unsigned index) { - assert(index < getNumThreadsDimsCount() && - "Num threads dims index out of bounds"); - if(getNumThreadsDimsValues().empty()) - return nullptr; - return getNumThreadsDimsValues()[index]; + /// Index must be less than getNumThreadsVals().size() + ::mlir::Value getNumThreadsVal(unsigned index) { + assert(index < getNumThreadsVals().size() && + "Num threads index out of bounds"); + return getNumThreadsVals()[index]; } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 5d75613f9b2b6..6ba2155c7840f 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -450,8 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, - /* num_threads_num_dims = */ nullptr, - /* num_threads_dims_values = */ numThreadsValues, + /* num_threads_vals = */ numThreadsValues, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, /* private_needs_barrier = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 6911272d43f6e..bc7647d129f60 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2252,8 +2252,7 @@ LogicalResult TargetOp::verifyRegions() { if (auto parallelOp = dyn_cast<ParallelOp>(user)) { if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && parallelOp->isAncestor(capturedOp) && - llvm::is_contained(parallelOp.getNumThreadsDimsValues(), - hostEvalArg)) + llvm::is_contained(parallelOp.getNumThreadsVals(), hostEvalArg)) continue; return emitOpError() @@ -2505,8 +2504,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, - /*num_threads_dims=*/nullptr, - /*num_threads_values=*/ValueRange(), + /*num_threads_vals=*/ValueRange(), /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, @@ -2519,8 +2517,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreadsNumDims, - clauses.numThreadsDimsValues, clauses.privateVars, + clauses.ifExpr, clauses.numThreadsVals, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, @@ -2571,23 +2568,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) { return success(); } -// Helper: Verify num_threads clause -LogicalResult -verifyNumThreadsClause(Operation *op, - std::optional<IntegerAttr> numThreadsNumDims, - OperandRange numThreadsDimsValues) { - if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues))) - return failure(); - return success(); -} - LogicalResult ParallelOp::verify() { - // verify num_threads clause restrictions - if (failed(verifyNumThreadsClause(getOperation(), - this->getNumThreadsNumDimsAttr(), - this->getNumThreadsDimsValues()))) - return failure(); - // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) return emitError( @@ -4622,30 +4603,24 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, // Parser and printer for num_threads clause //===----------------------------------------------------------------------===// static ParseResult -parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr, +parseNumThreadsClause(OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, SmallVectorImpl<Type> &types) { - if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { - return success(); - } - - // Without dims modifier: value : type - OpAsmParser::UnresolvedOperand singleValue; - Type singleType; - if (parser.parseOperand(singleValue) || parser.parseColon() || - parser.parseType(singleType)) { + // Parse comma-separated list of values with their types + // Format: %v1, %v2, ... : type1, type2, ... + if (parser.parseOperandList(values) || parser.parseColon() || + parser.parseTypeList(types)) { return failure(); } - values.push_back(singleValue); - types.push_back(singleType); return success(); } static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, - IntegerAttr dimsAttr, OperandRange values, - TypeRange types) { - // Multidimensional: dims(N): values : type - printDimsModifierWithValues(p, dimsAttr, values, types); + OperandRange values, TypeRange types) { + // Print values with their types + llvm::interleaveComma(values, p, [&](Value v) { p << v; }); + p << " : "; + llvm::interleaveComma(types, p, [&](Type t) { p << t; }); } #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index da44dda0a1230..2fd3da1b5b30a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.hasNumTeamsMultiDim()) result = todo("num_teams with multi-dimensional values"); }; + auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) { + if (op.hasNumThreadsMultiDim()) + result = todo("num_threads with multi-dimensional values"); + }; LogicalResult result = success(); llvm::TypeSwitch<Operation &>(op) @@ -431,6 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { .Case([&](omp::ParallelOp op) { checkAllocate(op, result); checkReduction(op, result); + checkNumThreadsMultiDim(op, result); }) .Case([&](omp::SimdOp op) { checkReduction(op, result); }) .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp, @@ -3268,11 +3273,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, if (auto ifVar = opInst.getIfExpr()) ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; - // num_threads dims and values are not yet supported - assert(!opInst.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is not yet implemented."); - if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0)) - numThreads = moduleTranslation.lookupValue(numThreadsVar); + if (!opInst.getNumThreadsVals().empty()) + numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0)); auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) pbKind = getProcBindKind(*bind); @@ -6053,11 +6055,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::ParallelOp parallelOp) { - // num_threads dims and values are not yet supported - assert(!parallelOp.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is not yet " - "implemented."); - if (parallelOp.getNumThreadsDimsValue(0) == blockArg) + if (!parallelOp.getNumThreadsVals().empty() && + parallelOp.getNumThreadsVal(0) == blockArg) numThreads = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6175,11 +6174,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, } if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { - // num_threads dims and values are not yet supported - assert( - !parallelOp.hasNumThreadsDimsModifier() && - "Lowering of num_threads with dims modifier is not yet implemented."); - numThreads = parallelOp.getNumThreadsDimsValue(0); + if (!parallelOp.getNumThreadsVals().empty()) + numThreads = parallelOp.getNumThreadsVal(0); } } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 8a5e64b1a98ca..bb882db73cbab 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -30,37 +30,6 @@ func.func @num_threads_once(%n : si32) { // ----- -func.func @num_threads_dims_no_values() { - // expected-error@+1 {{dims modifier requires values to be specified}} - "omp.parallel"() ({ - omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () - return -} - -// ----- - -func.func @num_threads_dims_mismatch(%n : i64) { - // expected-error@+1 {{dims(2) specified but 1 values provided}} - omp.parallel num_threads(dims(2): %n : i64) { - omp.terminator - } - - return -} - -// ----- - -func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) { - // expected-error@+1 {{dims values can only be specified with dims modifier}} - "omp.parallel"(%n, %m) ({ - omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> () - return -} - -// ----- - func.func @nowait_not_allowed(%n : memref<i32>) { // expected-error@+1 {{expected '{' to begin a region}} omp.parallel nowait {} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 4c57b8aea0b48..67f93869d4be7 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -160,8 +160,15 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre omp.terminator } - // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64) - omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) { + // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}} : i64, i64) + omp.parallel num_threads(%n_i64, %n_i64 : i64, i64) { + omp.terminator + } + + %n_i16 = arith.constant 8 : i16 + // Test num_threads with mixed types. + // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16) + omp.parallel num_threads(%num_threads, %n_i64, %n_i16 : i32, i64, i16) { omp.terminator } diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 3681ce38bd523..fd218e91d0b46 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -452,6 +452,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) { // ----- +llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) { + // expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.parallel}} + omp.parallel num_threads(%lb, %ub : i32, i32) { + omp.terminator + } + llvm.return +} + +// ----- + llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}} // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} >From 1033cc66ab5617df178499b6138a6f00a7da18f5 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Sat, 17 Jan 2026 10:37:09 +0530 Subject: [PATCH 7/7] remove custom parser printer for num_threads --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 10 ++++---- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 24 ------------------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 ++++---- 3 files changed, 9 insertions(+), 35 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 90bff92fbc826..7d0e1e3f91af4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1073,9 +1073,7 @@ class OpenMP_NumThreadsClauseSkip< ); let optAssemblyFormat = [{ - `num_threads` `(` custom<NumThreadsClause>( - $num_threads_vals, type($num_threads_vals) - ) `)` + `num_threads` `(` $num_threads_vals `:` type($num_threads_vals) `)` }]; let description = [{ @@ -1107,10 +1105,10 @@ class OpenMP_NumThreadsClauseSkip< /// Returns the value for a specific dimension index /// Index must be less than getNumThreadsVals().size() - ::mlir::Value getNumThreadsVal(unsigned index) { - assert(index < getNumThreadsVals().size() && + ::mlir::Value getNumThreads(unsigned dim = 0) { + assert(dim < getNumThreadsDimsCount() && "Num threads index out of bounds"); - return getNumThreadsVals()[index]; + return getNumThreadsVals()[dim]; } }]; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index bc7647d129f60..ab1038c755f7a 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4599,30 +4599,6 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, } } -//===----------------------------------------------------------------------===// -// Parser and printer for num_threads clause -//===----------------------------------------------------------------------===// -static ParseResult -parseNumThreadsClause(OpAsmParser &parser, - SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, - SmallVectorImpl<Type> &types) { - // Parse comma-separated list of values with their types - // Format: %v1, %v2, ... : type1, type2, ... - if (parser.parseOperandList(values) || parser.parseColon() || - parser.parseTypeList(types)) { - return failure(); - } - return success(); -} - -static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, - OperandRange values, TypeRange types) { - // Print values with their types - llvm::interleaveComma(values, p, [&](Value v) { p << v; }); - p << " : "; - llvm::interleaveComma(types, p, [&](Type t) { p << t; }); -} - #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 2fd3da1b5b30a..b92ec9332d43a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -380,7 +380,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.hasNumTeamsMultiDim()) result = todo("num_teams with multi-dimensional values"); }; - auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) { + auto checkNumThreads = [&todo](auto op, LogicalResult &result) { if (op.hasNumThreadsMultiDim()) result = todo("num_threads with multi-dimensional values"); }; @@ -435,7 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { .Case([&](omp::ParallelOp op) { checkAllocate(op, result); checkReduction(op, result); - checkNumThreadsMultiDim(op, result); + checkNumThreads(op, result); }) .Case([&](omp::SimdOp op) { checkReduction(op, result); }) .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp, @@ -3274,7 +3274,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; if (!opInst.getNumThreadsVals().empty()) - numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0)); + numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0)); auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) pbKind = getProcBindKind(*bind); @@ -6056,7 +6056,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, }) .Case([&](omp::ParallelOp parallelOp) { if (!parallelOp.getNumThreadsVals().empty() && - parallelOp.getNumThreadsVal(0) == blockArg) + parallelOp.getNumThreads(0) == blockArg) numThreads = hostEvalVar; else llvm_unreachable("unsupported host_eval use"); @@ -6175,7 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { if (!parallelOp.getNumThreadsVals().empty()) - numThreads = parallelOp.getNumThreadsVal(0); + numThreads = parallelOp.getNumThreads(0); } } _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
