https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767

>From 13ff8e42b2d4df4be49f3a50400c54f3d5f8f219 Mon Sep 17 00:00:00 2001
From: skc7 <[email protected]>
Date: Thu, 11 Dec 2025 11:56:58 +0530
Subject: [PATCH 1/4] [OpenMP][MLIR] Add num_threads clause with dims modifier
 support

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 50 +++++++++++-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |  2 +
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 79 +++++++++++++++++--
 mlir/test/Dialect/OpenMP/invalid.mlir         | 33 +++++++-
 mlir/test/Dialect/OpenMP/ops.mlir             | 15 ++--
 5 files changed, 163 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td 
b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index b949e2629a095..56cfd016ef52b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
+    Variadic<AnyInteger>:$num_threads_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` $num_threads `:` type($num_threads) `)`
+    `num_threads` `(` custom<NumThreadsClause>(
+      $num_threads_dims, $num_threads_values, type($num_threads_values),
+      $num_threads, type($num_threads)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_threads` parameter specifies the number of threads which
-    should be used to execute the parallel region.
+    num_threads clause specifies the desired number of threads in the team
+    space formed by the construct on which it appears.
+
+    With dims modifier:
+    - Uses `num_threads_dims` (dimension count) and `num_threads_values` 
(upper bounds list)
+    - Specifies upper bounds for each dimension (all must have same type)
+    - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : 
type)`
+    - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_threads`
+    - If lower bound not specified, it defaults to upper bound value
+    - Format: `num_threads(bounds : type)`
+    - Example: `num_threads(%ub : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasDimsModifier() {
+      return getNumThreadsDims().has_value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getNumDimensions() {
+      if (!hasDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumThreadsDims());
+    }
+
+    /// Returns all dimension values as an operand range
+    ::mlir::OperandRange getDimensionValues() {
+      return getNumThreadsValues();
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumDimensions()
+    ::mlir::Value getDimensionValue(unsigned index) {
+      assert(index < getDimensionValues().size() &&
+             "Dimension index out of bounds");
+      return getDimensionValues()[index];
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp 
b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..0d5333ec2e455 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public 
OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
+        /* num_threads_dims = */ nullptr,
+        /* num_threads_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp 
b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index f9dcb48c7f08b..f1f3e5c0b691b 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2538,6 +2538,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState 
&state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+                    /*num_threads_dims=*/nullptr,
+                    /*num_threads_values=*/ValueRange(),
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, 
/*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
@@ -2549,13 +2551,14 @@ void ParallelOp::build(OpBuilder &builder, 
OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(builder, state, clauses.allocateVars, 
clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
-                    makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.privateNeedsBarrier, clauses.procBindKind,
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+      clauses.numThreads, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2602,13 +2605,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
 }
 
 LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  auto numThreadsDims = getNumThreadsDims();
+  auto numThreadsValues = getNumThreadsValues();
+  auto numThreads = getNumThreads();
+
+  // num_threads with dims modifier
+  if (numThreadsDims.has_value() && numThreadsValues.empty()) {
+    return emitError(
+        "num_threads dims modifier requires values to be specified");
+  }
+
+  if (numThreadsDims.has_value() &&
+      numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
+    return emitError("num_threads dims(")
+           << *numThreadsDims << ") specified but " << numThreadsValues.size()
+           << " values provided";
+  }
+
+  // num_threads dims and number of threads cannot be used together
+  if (numThreadsDims.has_value() && numThreads) {
+    return emitError(
+        "num_threads dims and number of threads cannot be used together");
+  }
+
+  // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // verify private variables restrictions
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  // verify reduction variables restrictions
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4652,6 +4682,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, 
Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      SmallVectorImpl<Type> &types,
+                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+                      Type &boundsType) {
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) 
{
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand boundsOperand;
+  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+      parser.parseType(boundsType)) {
+    return failure();
+  }
+  bounds = boundsOperand;
+  return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+                                  IntegerAttr dimsAttr, OperandRange values,
+                                  TypeRange types, Value bounds,
+                                  Type boundsType) {
+  if (!values.empty()) {
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  }
+  if (bounds) {
+    p.printOperand(bounds);
+    p << " : " << boundsType;
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir 
b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..9e2e5722aab9f 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
 
 // -----
 
+func.func @num_threads_dims_no_values() {
+  // expected-error@+1 {{num_threads dims modifier requires values to be 
specified}}
+  "omp.parallel"() ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : 
i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+  // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+  omp.parallel num_threads(dims(2): %n : i64) {
+    omp.terminator
+  }
+
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+  // expected-error@+1 {{num_threads dims and number of threads cannot be used 
together}}
+  "omp.parallel"(%n, %n, %m) ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : 
i64} : (i64, i64, i64) -> ()
+  return
+}
+
+// -----
+
 func.func @nowait_not_allowed(%n : memref<i32>) {
   // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and 
privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, 
private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 
0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir 
b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : 
i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, 
memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, 
memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : 
i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, 
memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, 
memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = 
#omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = 
#omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, 
memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, 
memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, 
%if_cond : i1, %num_thre
    omp.terminator
  }
 
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+   omp.terminator
+ }
+
  // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>)
  omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
    omp.terminator

>From 16ab3e92066d615fcc877e886ca1d5cfbeba3d9f Mon Sep 17 00:00:00 2001
From: skc7 <[email protected]>
Date: Thu, 11 Dec 2025 12:11:49 +0530
Subject: [PATCH 2/4] Mark mlir->llvmir translation for num_threads with dims
 as NYI

---
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp  | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cd5986f7e9ef9..cdeb1872ca92a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2888,6 +2888,10 @@ convertOmpParallel(omp::ParallelOp opInst, 
llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
+  // num_threads dims and values are not yet supported
+  assert(!opInst.getNumThreadsDims().has_value() &&
+         opInst.getNumThreadsValues().empty() &&
+         "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5656,6 +5660,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value 
&numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
+            // num_threads dims and values are not yet supported
+            assert(!parallelOp.getNumThreadsDims().has_value() &&
+                   parallelOp.getNumThreadsValues().empty() &&
+                   "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
             else
@@ -5777,8 +5785,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation 
*capturedOp,
       threadLimit = teamsOp.getThreadLimit();
     }
 
-    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+      // num_threads dims and values are not yet supported
+      assert(!parallelOp.getNumThreadsDims().has_value() &&
+             parallelOp.getNumThreadsValues().empty() &&
+             "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
+    }
   }
 
   // Handle clauses impacting the number of teams.

>From c96be363f3d404e10310ed3dede806da2d5096ad Mon Sep 17 00:00:00 2001
From: skc7 <[email protected]>
Date: Thu, 11 Dec 2025 17:37:52 +0530
Subject: [PATCH 3/4] Rename num_threads dims operands; reuse verifyDimsModifier

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 33 ++++++--------
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |  4 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 44 +++++++++----------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  9 ++--
 mlir/test/Dialect/OpenMP/invalid.mlir         | 10 ++---
 5 files changed, 45 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td 
b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 56cfd016ef52b..b5c51c5a8dff3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
-    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
-    Variadic<AnyInteger>:$num_threads_values,
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+    Variadic<AnyInteger>:$num_threads_dims_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
     `num_threads` `(` custom<NumThreadsClause>(
-      $num_threads_dims, $num_threads_values, type($num_threads_values),
+      $num_threads_num_dims, $num_threads_dims_values, 
type($num_threads_dims_values),
       $num_threads, type($num_threads)
     ) `)`
   }];
@@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip<
     space formed by the construct on which it appears.
 
     With dims modifier:
-    - Uses `num_threads_dims` (dimension count) and `num_threads_values` 
(upper bounds list)
+    - Uses `num_threads_num_dims` (dimension count) and 
`num_threads_dims_values` (upper bounds list)
     - Specifies upper bounds for each dimension (all must have same type)
     - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : 
type)`
     - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
@@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip<
 
   let extraClassDeclaration = [{
     /// Returns true if the dims modifier is explicitly present
-    bool hasDimsModifier() {
-      return getNumThreadsDims().has_value();
+    bool hasNumThreadsDimsModifier() {
+      return getNumThreadsNumDims().has_value() && 
getNumThreadsNumDims().value();
     }
 
     /// Returns the number of dimensions specified by dims modifier
-    unsigned getNumDimensions() {
-      if (!hasDimsModifier())
+    unsigned getNumThreadsDimsCount() {
+      if (!hasNumThreadsDimsModifier())
         return 1;
-      return static_cast<unsigned>(*getNumThreadsDims());
-    }
-
-    /// Returns all dimension values as an operand range
-    ::mlir::OperandRange getDimensionValues() {
-      return getNumThreadsValues();
+      return static_cast<unsigned>(*getNumThreadsNumDims());
     }
 
     /// Returns the value for a specific dimension index
-    /// Index must be less than getNumDimensions()
-    ::mlir::Value getDimensionValue(unsigned index) {
-      assert(index < getDimensionValues().size() &&
-             "Dimension index out of bounds");
-      return getDimensionValues()[index];
+    /// Index must be less than getNumThreadsDimsCount()
+    ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+      assert(index < getNumThreadsDimsCount() &&
+             "Num threads dims index out of bounds");
+      return getNumThreadsDimsValues()[index];
     }
   }];
 }
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp 
b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 0d5333ec2e455..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,8 +448,8 @@ struct ParallelOpLowering : public 
OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
-        /* num_threads_dims = */ nullptr,
-        /* num_threads_values = */ llvm::SmallVector<Value>{},
+        /* num_threads_num_dims = */ nullptr,
+        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp 
b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index f1f3e5c0b691b..1c8642f3d5290 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2553,7 +2553,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState 
&state,
   MLIRContext *ctx = builder.getContext();
   ParallelOp::build(
       builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
       clauses.numThreads, clauses.privateVars,
       makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
       clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
@@ -2604,30 +2604,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
   return success();
 }
 
-LogicalResult ParallelOp::verify() {
-  // verify num_threads clause restrictions
-  auto numThreadsDims = getNumThreadsDims();
-  auto numThreadsValues = getNumThreadsValues();
-  auto numThreads = getNumThreads();
-
-  // num_threads with dims modifier
-  if (numThreadsDims.has_value() && numThreadsValues.empty()) {
-    return emitError(
-        "num_threads dims modifier requires values to be specified");
-  }
-
-  if (numThreadsDims.has_value() &&
-      numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
-    return emitError("num_threads dims(")
-           << *numThreadsDims << ") specified but " << numThreadsValues.size()
-           << " values provided";
+// Verify num_threads: reject dims + plain count; delegate dims checks.
+static LogicalResult
+verifyNumThreadsClause(Operation *op,
+                       std::optional<IntegerAttr> numThreadsNumDims,
+                       OperandRange numThreadsDimsValues, Value numThreads) {
+  bool hasDimsModifier =
+      numThreadsNumDims.has_value() && numThreadsNumDims.value();
+  if (hasDimsModifier && numThreads) {
+    return op->emitError("num_threads with dims modifier cannot be used "
+                         "together with number of threads");
   }
+  if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+    return failure();
+  return success();
+}
 
-  // num_threads dims and number of threads cannot be used together
-  if (numThreadsDims.has_value() && numThreads) {
-    return emitError(
-        "num_threads dims and number of threads cannot be used together");
-  }
+LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  if (failed(verifyNumThreadsClause(
+          getOperation(), this->getNumThreadsNumDimsAttr(),
+          this->getNumThreadsDimsValues(), this->getNumThreads())))
+    return failure();
 
   // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cdeb1872ca92a..7eb8a90bd4c00 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2889,8 +2889,7 @@ convertOmpParallel(omp::ParallelOp opInst, 
llvm::IRBuilderBase &builder,
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
   // num_threads dims and values are not yet supported
-  assert(!opInst.getNumThreadsDims().has_value() &&
-         opInst.getNumThreadsValues().empty() &&
+  assert(!opInst.hasNumThreadsDimsModifier() &&
          "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
@@ -5661,8 +5660,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value 
&numThreads,
           })
           .Case([&](omp::ParallelOp parallelOp) {
             // num_threads dims and values are not yet supported
-            assert(!parallelOp.getNumThreadsDims().has_value() &&
-                   parallelOp.getNumThreadsValues().empty() &&
+            assert(!parallelOp.hasNumThreadsDimsModifier() &&
                    "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
@@ -5787,8 +5785,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation 
*capturedOp,
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
       // num_threads dims and values are not yet supported
-      assert(!parallelOp.getNumThreadsDims().has_value() &&
-             parallelOp.getNumThreadsValues().empty() &&
+      assert(!parallelOp.hasNumThreadsDimsModifier() &&
              "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
     }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir 
b/mlir/test/Dialect/OpenMP/invalid.mlir
index 9e2e5722aab9f..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) {
 // -----
 
 func.func @num_threads_dims_no_values() {
-  // expected-error@+1 {{num_threads dims modifier requires values to be 
specified}}
+  // expected-error@+1 {{dims modifier requires values to be specified}}
   "omp.parallel"() ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : 
i64} : () -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 
2 : i64} : () -> ()
   return
 }
 
 // -----
 
 func.func @num_threads_dims_mismatch(%n : i64) {
-  // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+  // expected-error@+1 {{dims(2) specified but 1 values provided}}
   omp.parallel num_threads(dims(2): %n : i64) {
     omp.terminator
   }
@@ -52,10 +52,10 @@ func.func @num_threads_dims_mismatch(%n : i64) {
 // -----
 
 func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
-  // expected-error@+1 {{num_threads dims and number of threads cannot be used 
together}}
+  // expected-error@+1 {{num_threads with dims modifier cannot be used 
together with number of threads}}
   "omp.parallel"(%n, %n, %m) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : 
i64} : (i64, i64, i64) -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 
2 : i64} : (i64, i64, i64) -> ()
   return
 }
 

>From c6acf73667df2016a5d89aa9b71c2a727a42f67c Mon Sep 17 00:00:00 2001
From: skc7 <[email protected]>
Date: Fri, 19 Dec 2025 12:27:38 +0530
Subject: [PATCH 4/4] Use num_threads_dims_values only

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |  4 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 15 ++---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 15 +++--
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |  5 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 62 ++++++++-----------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 16 ++---
 mlir/test/Dialect/OpenMP/invalid.mlir         | 12 ++--
 mlir/test/Dialect/OpenMP/ops.mlir             | 10 +--
 8 files changed, 66 insertions(+), 73 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp 
b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 3c31b3a07f57f..1a2700f50508e 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -514,8 +514,8 @@ bool ClauseProcessor::processNumThreads(
     mlir::omp::NumThreadsClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
     // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
-    result.numThreads =
-        fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+    result.numThreadsDimsValues.push_back(
+        fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
     return true;
   }
   return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp 
b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 38ab42076f559..80219e133a11c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,8 +99,8 @@ class HostEvalInfo {
     if (ops.numTeamsUpper)
       vars.push_back(ops.numTeamsUpper);
 
-    if (ops.numThreads)
-      vars.push_back(ops.numThreads);
+    for (auto numThreads : ops.numThreadsDimsValues)
+      vars.push_back(numThreads);
 
     if (ops.threadLimit)
       vars.push_back(ops.threadLimit);
@@ -115,7 +115,8 @@ class HostEvalInfo {
     assert(args.size() ==
                ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
                    ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
-                   (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+                   (ops.numTeamsUpper ? 1 : 0) +
+                   ops.numThreadsDimsValues.size() +
                    (ops.threadLimit ? 1 : 0) &&
            "invalid block argument list");
     int argIndex = 0;
@@ -134,8 +135,8 @@ class HostEvalInfo {
     if (ops.numTeamsUpper)
       ops.numTeamsUpper = args[argIndex++];
 
-    if (ops.numThreads)
-      ops.numThreads = args[argIndex++];
+    for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i)
+      ops.numThreadsDimsValues[i] = args[argIndex++];
 
     if (ops.threadLimit)
       ops.threadLimit = args[argIndex++];
@@ -169,13 +170,13 @@ class HostEvalInfo {
   /// \returns whether an update was performed. If not, these clauses were not
   ///          evaluated in the host device.
   bool apply(mlir::omp::ParallelOperands &clauseOps) {
-    if (!ops.numThreads || parallelApplied) {
+    if (ops.numThreadsDimsValues.empty() || parallelApplied) {
       parallelApplied = true;
       return false;
     }
 
     parallelApplied = true;
-    clauseOps.numThreads = ops.numThreads;
+    clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
     return true;
   }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td 
b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index b5c51c5a8dff3..7236baf33a15b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1070,14 +1070,12 @@ class OpenMP_NumThreadsClauseSkip<
                     extraClassDeclaration> {
   let arguments = (ins
     ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
-    Variadic<AnyInteger>:$num_threads_dims_values,
-    Optional<IntLikeType>:$num_threads
+    Variadic<IntLikeType>:$num_threads_dims_values
   );
 
   let optAssemblyFormat = [{
     `num_threads` `(` custom<NumThreadsClause>(
-      $num_threads_num_dims, $num_threads_dims_values, 
type($num_threads_dims_values),
-      $num_threads, type($num_threads)
+      $num_threads_num_dims, $num_threads_dims_values, 
type($num_threads_dims_values)
     ) `)`
   }];
 
@@ -1092,10 +1090,9 @@ class OpenMP_NumThreadsClauseSkip<
     - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
 
     Without dims modifier:
-    - Uses `num_threads`
-    - If lower bound not specified, it defaults to upper bound value
-    - Format: `num_threads(bounds : type)`
-    - Example: `num_threads(%ub : i32)`
+    - The number of threads is specified by a single value in `num_threads_dims_values`
+    - Format: `num_threads(value : type)`
+    - Example: `num_threads(%n : i32)`
   }];
 
   let extraClassDeclaration = [{
@@ -1116,6 +1113,8 @@ class OpenMP_NumThreadsClauseSkip<
     ::mlir::Value getNumThreadsDimsValue(unsigned index) {
       assert(index < getNumThreadsDimsCount() &&
              "Num threads dims index out of bounds");
+      if(getNumThreadsDimsValues().empty())
+        return nullptr;
       return getNumThreadsDimsValues()[index];
     }
   }];
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp 
b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ab7bded7835be..5d75613f9b2b6 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -438,9 +438,11 @@ struct ParallelOpLowering : public 
OpRewritePattern<scf::ParallelOp> {
     rewriter.eraseOp(reduce);
 
     Value numThreadsVar;
+    SmallVector<Value> numThreadsValues;
     if (numThreads > 0) {
       numThreadsVar = LLVM::ConstantOp::create(
           rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
+      numThreadsValues.push_back(numThreadsVar);
     }
     // Create the parallel wrapper.
     auto ompParallel = omp::ParallelOp::create(
@@ -449,8 +451,7 @@ struct ParallelOpLowering : public 
OpRewritePattern<scf::ParallelOp> {
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
         /* num_threads_num_dims = */ nullptr,
-        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
-        /* num_threads = */ numThreadsVar,
+        /* num_threads_dims_values = */ numThreadsValues,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
         /* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp 
b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1c8642f3d5290..26730443c96c9 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2286,7 +2286,8 @@ LogicalResult TargetOp::verifyRegions() {
       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
         if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
             parallelOp->isAncestor(capturedOp) &&
-            hostEvalArg == parallelOp.getNumThreads())
+            llvm::is_contained(parallelOp.getNumThreadsDimsValues(),
+                               hostEvalArg))
           continue;
 
         return emitOpError()
@@ -2540,7 +2541,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState 
&state,
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
                     /*num_threads_dims=*/nullptr,
                     /*num_threads_values=*/ValueRange(),
-                    /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
+                    /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, 
/*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
                     /*reduction_mod =*/nullptr, 
/*reduction_vars=*/ValueRange(),
@@ -2551,14 +2552,14 @@ void ParallelOp::build(OpBuilder &builder, 
OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(
-      builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
-      clauses.numThreads, clauses.privateVars,
-      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
-      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
-      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-      makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(builder, state, clauses.allocateVars, 
clauses.allocatorVars,
+                    clauses.ifExpr, clauses.numThreadsNumDims,
+                    clauses.numThreadsDimsValues, clauses.privateVars,
+                    makeArrayAttr(ctx, clauses.privateSyms),
+                    clauses.privateNeedsBarrier, clauses.procBindKind,
+                    clauses.reductionMod, clauses.reductionVars,
+                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+                    makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2608,13 +2609,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
 LogicalResult
 verifyNumThreadsClause(Operation *op,
                        std::optional<IntegerAttr> numThreadsNumDims,
-                       OperandRange numThreadsDimsValues, Value numThreads) {
-  bool hasDimsModifier =
-      numThreadsNumDims.has_value() && numThreadsNumDims.value();
-  if (hasDimsModifier && numThreads) {
-    return op->emitError("num_threads with dims modifier cannot be used "
-                         "together with number of threads");
-  }
+                       OperandRange numThreadsDimsValues) {
   if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
     return failure();
   return success();
@@ -2622,9 +2617,9 @@ verifyNumThreadsClause(Operation *op,
 
 LogicalResult ParallelOp::verify() {
   // verify num_threads clause restrictions
-  if (failed(verifyNumThreadsClause(
-          getOperation(), this->getNumThreadsNumDimsAttr(),
-          this->getNumThreadsDimsValues(), this->getNumThreads())))
+  if (failed(verifyNumThreadsClause(getOperation(),
+                                    this->getNumThreadsNumDimsAttr(),
+                                    this->getNumThreadsDimsValues())))
     return failure();
 
   // verify allocate clause restrictions
@@ -4686,33 +4681,28 @@ static void printNumTeamsClause(OpAsmPrinter &p, 
Operation *op,
 static ParseResult
 parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                      SmallVectorImpl<Type> &types,
-                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
-                      Type &boundsType) {
+                      SmallVectorImpl<Type> &types) {
   if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) 
{
     return success();
   }
 
-  OpAsmParser::UnresolvedOperand boundsOperand;
-  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
-      parser.parseType(boundsType)) {
+  // Without dims modifier: value : type
+  OpAsmParser::UnresolvedOperand singleValue;
+  Type singleType;
+  if (parser.parseOperand(singleValue) || parser.parseColon() ||
+      parser.parseType(singleType)) {
     return failure();
   }
-  bounds = boundsOperand;
+  values.push_back(singleValue);
+  types.push_back(singleType);
   return success();
 }
 
 static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
                                   IntegerAttr dimsAttr, OperandRange values,
-                                  TypeRange types, Value bounds,
-                                  Type boundsType) {
-  if (!values.empty()) {
-    printDimsModifierWithValues(p, dimsAttr, values, types);
-  }
-  if (bounds) {
-    p.printOperand(bounds);
-    p << " : " << boundsType;
-  }
+                                  TypeRange types) {
+  // Prints `dims(N): values : type` with the dims modifier, or `value : type` without it.
+  printDimsModifierWithValues(p, dimsAttr, values, types);
 }
 
 #define GET_ATTRDEF_CLASSES
diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 7eb8a90bd4c00..b90507ec7851f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2890,8 +2890,8 @@ convertOmpParallel(omp::ParallelOp opInst, 
llvm::IRBuilderBase &builder,
   llvm::Value *numThreads = nullptr;
   // num_threads dims and values are not yet supported
   assert(!opInst.hasNumThreadsDimsModifier() &&
-         "Lowering of num_threads with dims modifier is NYI.");
-  if (auto numThreadsVar = opInst.getNumThreads())
+         "Lowering of num_threads with dims modifier is not yet implemented.");
+  if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0))
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
   if (auto bind = opInst.getProcBindKind())
@@ -5661,8 +5661,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value 
&numThreads,
           .Case([&](omp::ParallelOp parallelOp) {
             // num_threads dims and values are not yet supported
             assert(!parallelOp.hasNumThreadsDimsModifier() &&
-                   "Lowering of num_threads with dims modifier is NYI.");
-            if (parallelOp.getNumThreads() == blockArg)
+                   "Lowering of num_threads with dims modifier is not yet "
+                   "implemented.");
+            if (parallelOp.getNumThreadsDimsValue(0) == blockArg)
               numThreads = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -5785,9 +5786,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation 
*capturedOp,
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
       // num_threads dims and values are not yet supported
-      assert(!parallelOp.hasNumThreadsDimsModifier() &&
-             "Lowering of num_threads with dims modifier is NYI.");
-      numThreads = parallelOp.getNumThreads();
+      assert(
+          !parallelOp.hasNumThreadsDimsModifier() &&
+          "Lowering of num_threads with dims modifier is not yet 
implemented.");
+      numThreads = parallelOp.getNumThreadsDimsValue(0);
     }
   }
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir 
b/mlir/test/Dialect/OpenMP/invalid.mlir
index db0ddcb415d42..b05bbd4056525 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -34,7 +34,7 @@ func.func @num_threads_dims_no_values() {
   // expected-error@+1 {{dims modifier requires values to be specified}}
   "omp.parallel"() ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 
2 : i64} : () -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 
: i64} : () -> ()
   return
 }
 
@@ -51,11 +51,11 @@ func.func @num_threads_dims_mismatch(%n : i64) {
 
 // -----
 
-func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
-  // expected-error@+1 {{num_threads with dims modifier cannot be used 
together with number of threads}}
-  "omp.parallel"(%n, %n, %m) ({
+func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) {
+  // expected-error@+1 {{dims values can only be specified with dims modifier}}
+  "omp.parallel"(%n, %m) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 
2 : i64} : (i64, i64, i64) -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> ()
   return
 }
 
@@ -2797,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and 
privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 
0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, 
private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir 
b/mlir/test/Dialect/OpenMP/ops.mlir
index 585c9483c08a9..004313eaa6ff1 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : 
i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, 
memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, 
memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : 
i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, 
memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, 
memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = 
#omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = 
#omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, 
memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, 
memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to