https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/175790
>From 9ced34d1958c0da71fc8b87f20671acab0e4cd54 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Tue, 13 Jan 2026 21:24:34 +0530 Subject: [PATCH] [FLANG] Add flang to mlir lowering for num_teams --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 37 ++++++++++----- flang/lib/Lower/OpenMP/Clauses.cpp | 27 +++++++++-- flang/lib/Lower/OpenMP/OpenMP.cpp | 18 ++++++-- flang/test/Lower/OpenMP/num-teams-dims.f90 | 52 ++++++++++++++++++++++ 4 files changed, 117 insertions(+), 17 deletions(-) create mode 100644 flang/test/Lower/OpenMP/num-teams-dims.f90 diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index b923e415231d6..579ce359ed357 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -495,17 +495,34 @@ bool ClauseProcessor::processSizes(StatementContext &stmtCtx, bool ClauseProcessor::processNumTeams( lower::StatementContext &stmtCtx, mlir::omp::NumTeamsClauseOps &result) const { - // TODO Get lower and upper bounds for num_teams when parser is updated to - // accept both. if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) { - // The num_teams directive accepts a list of team lower/upper bounds. - // This is an extension to support grid specification for ompx_bare. - // Here, only expect a single element in the list. - assert(clause->v.size() == 1); - // auto lowerBound = std::get<std::optional<ExprTy>>(clause->v[0]->t); - auto &upperBound = std::get<ExprTy>(clause->v[0].t); - result.numTeamsUpper = - fir::getBase(converter.genExprValue(upperBound, stmtCtx)); + // The num_teams clause accepts a list of upper bounds. 
+ // With dims modifier: multiple upper bounds for multi-dimensional grid + // Without dims modifier: single Range with optional lower/upper bounds + assert(!clause->v.empty()); + + // Check if dims modifier is present (indicated by having multiple elements + // in the list, or single element without lower bound but with multiple + // upper bounds from dims modifier parsing) + if (clause->v.size() > 1) { + // Dims modifier case: multiple upper bounds + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + result.numTeamsNumDims = firOpBuilder.getI64IntegerAttr(clause->v.size()); + for (const auto &range : clause->v) { + auto &upperBound = std::get<ExprTy>(range.t); + result.numTeamsDimsValues.push_back( + fir::getBase(converter.genExprValue(upperBound, stmtCtx))); + } + } else { + // Legacy case: single element with optional lower and upper bounds + auto &lowerBound = std::get<std::optional<ExprTy>>(clause->v[0].t); + auto &upperBound = std::get<ExprTy>(clause->v[0].t); + if (lowerBound) + result.numTeamsLower = + fir::getBase(converter.genExprValue(*lowerBound, stmtCtx)); + result.numTeamsUpper = + fir::getBase(converter.genExprValue(upperBound, stmtCtx)); + } return true; } return false; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index a2716fb22a75c..01d0ff963ecf3 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -1297,10 +1297,29 @@ NumTasks make(const parser::OmpClause::NumTasks &inp, NumTeams make(const parser::OmpClause::NumTeams &inp, semantics::SemanticsContext &semaCtx) { // inp.v -> parser::OmpNumTeamsClause - auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t); - assert(!t1.empty()); - List<NumTeams::Range> v{{{/*LowerBound=*/std::nullopt, - /*UpperBound=*/makeExpr(t1.front(), semaCtx)}}}; + auto &mods = semantics::OmpGetModifiers(inp.v); + auto *dims = semantics::OmpGetUniqueModifier<parser::OmpDimsModifier>(mods); + auto *lowerBound = + 
semantics::OmpGetUniqueModifier<parser::OmpLowerBound>(mods); + auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t); + assert(!values.empty()); + + // With dims modifier: create Range for each value (all upper bounds) + // The dims modifier value is stored as the list size matching dims count. + // Without dims modifier: single Range with optional lower bound + if (dims) { + List<NumTeams::Range> v; + for (const auto &val : values) { + v.push_back(NumTeams::Range{{/*LowerBound=*/std::nullopt, + /*UpperBound=*/makeExpr(val, semaCtx)}}); + } + return NumTeams{/*List=*/v}; + } + + // Without dims modifier: single element with optional lower bound + auto lb = maybeApplyToV(makeExprFn(semaCtx), lowerBound); + List<NumTeams::Range> v{{{/*LowerBound=*/lb, + /*UpperBound=*/makeExpr(values.front(), semaCtx)}}}; return NumTeams{/*List=*/v}; } diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 989e370870f33..e7a3cfcf52cd2 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -99,6 +99,10 @@ class HostEvalInfo { if (ops.numTeamsUpper) vars.push_back(ops.numTeamsUpper); + // num_teams with dims modifier (OpenMP 6.1) + for (mlir::Value val : ops.numTeamsDimsValues) + vars.push_back(val); + if (ops.numThreads) vars.push_back(ops.numThreads); @@ -115,8 +119,8 @@ class HostEvalInfo { assert(args.size() == ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + - (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + - (ops.threadLimit ? 1 : 0) && + (ops.numTeamsUpper ? 1 : 0) + ops.numTeamsDimsValues.size() + + (ops.numThreads ? 1 : 0) + (ops.threadLimit ? 
1 : 0) && "invalid block argument list"); int argIndex = 0; for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i) @@ -134,6 +138,10 @@ class HostEvalInfo { if (ops.numTeamsUpper) ops.numTeamsUpper = args[argIndex++]; + // num_teams with dims modifier (OpenMP 6.1) + for (size_t i = 0; i < ops.numTeamsDimsValues.size(); ++i) + ops.numTeamsDimsValues[i] = args[argIndex++]; + if (ops.numThreads) ops.numThreads = args[argIndex++]; @@ -185,11 +193,15 @@ class HostEvalInfo { /// \returns whether an update was performed. If not, these clauses were not /// evaluated in the host device. bool apply(mlir::omp::TeamsOperands &clauseOps) { - if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit) + if (!ops.numTeamsLower && !ops.numTeamsUpper && + ops.numTeamsDimsValues.empty() && !ops.threadLimit) return false; clauseOps.numTeamsLower = ops.numTeamsLower; clauseOps.numTeamsUpper = ops.numTeamsUpper; + // num_teams with dims modifier (OpenMP 6.1) + clauseOps.numTeamsDimsValues = ops.numTeamsDimsValues; + clauseOps.numTeamsNumDims = ops.numTeamsNumDims; clauseOps.threadLimit = ops.threadLimit; return true; } diff --git a/flang/test/Lower/OpenMP/num-teams-dims.f90 b/flang/test/Lower/OpenMP/num-teams-dims.f90 new file mode 100644 index 0000000000000..fd5d5ba40c804 --- /dev/null +++ b/flang/test/Lower/OpenMP/num-teams-dims.f90 @@ -0,0 +1,52 @@ +! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s + +!=============================================================================== +! `num_teams` clause with dims modifier (OpenMP 6.1) +!=============================================================================== + +! CHECK-LABEL: func @_QPteams_numteams_dims2 +subroutine teams_numteams_dims2() + ! CHECK: omp.teams + ! CHECK-SAME: num_teams(dims(2): %{{.*}}, %{{.*}} : i32) + !$omp teams num_teams(dims(2): 10, 4) + call f1() + ! CHECK: omp.terminator + !$omp end teams +end subroutine teams_numteams_dims2 + +! 
CHECK-LABEL: func @_QPteams_numteams_dims3 +subroutine teams_numteams_dims3() + ! CHECK: omp.teams + ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32) + !$omp teams num_teams(dims(3): 8, 4, 2) + call f1() + ! CHECK: omp.terminator + !$omp end teams +end subroutine teams_numteams_dims3 + +! CHECK-LABEL: func @_QPteams_numteams_dims_var +subroutine teams_numteams_dims_var(a, b, c) + integer, intent(in) :: a, b, c + ! CHECK: omp.teams + ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32) + !$omp teams num_teams(dims(3): a, b, c) + call f1() + ! CHECK: omp.terminator + !$omp end teams +end subroutine teams_numteams_dims_var + +!=============================================================================== +! `num_teams` clause with lower bound (legacy, without dims) +!=============================================================================== + +! CHECK-LABEL: func @_QPteams_numteams_lower_upper +subroutine teams_numteams_lower_upper(lower, upper) + integer, intent(in) :: lower, upper + ! CHECK: omp.teams + ! CHECK-SAME: num_teams(%{{.*}} : i32 to %{{.*}} : i32) + !$omp teams num_teams(lower: upper) + call f1() + ! CHECK: omp.terminator + !$omp end teams +end subroutine teams_numteams_lower_upper + _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
