https://github.com/skc7 created https://github.com/llvm/llvm-project/pull/175790
None >From 75cf8e2211eb9dce46ed9f6f5e57643efaddf280 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Tue, 13 Jan 2026 21:24:34 +0530 Subject: [PATCH] [FLANG] Add flang to mlir lowering for num_teams --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 37 ++++++++++----- flang/lib/Lower/OpenMP/Clauses.cpp | 26 +++++++++-- flang/test/Lower/OpenMP/num-teams-dims.f90 | 52 ++++++++++++++++++++++ 3 files changed, 101 insertions(+), 14 deletions(-) create mode 100644 flang/test/Lower/OpenMP/num-teams-dims.f90 diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index b923e415231d6..579ce359ed357 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -495,17 +495,34 @@ bool ClauseProcessor::processSizes(StatementContext &stmtCtx, bool ClauseProcessor::processNumTeams( lower::StatementContext &stmtCtx, mlir::omp::NumTeamsClauseOps &result) const { - // TODO Get lower and upper bounds for num_teams when parser is updated to - // accept both. if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) { - // The num_teams directive accepts a list of team lower/upper bounds. - // This is an extension to support grid specification for ompx_bare. - // Here, only expect a single element in the list. - assert(clause->v.size() == 1); - // auto lowerBound = std::get<std::optional<ExprTy>>(clause->v[0]->t); - auto &upperBound = std::get<ExprTy>(clause->v[0].t); - result.numTeamsUpper = - fir::getBase(converter.genExprValue(upperBound, stmtCtx)); + // The num_teams clause accepts a list of upper bounds. + // With dims modifier: multiple upper bounds for multi-dimensional grid + // Without dims modifier: single Range with optional lower/upper bounds + assert(!clause->v.empty()); + + // Check if dims modifier is present (indicated by having multiple elements + // in the list, or single element without lower bound but with multiple + // upper bounds from dims modifier parsing) + if (clause->v.size() > 1) { + // Dims modifier case: multiple upper bounds + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + result.numTeamsNumDims = firOpBuilder.getI64IntegerAttr(clause->v.size()); + for (const auto &range : clause->v) { + auto &upperBound = std::get<ExprTy>(range.t); + result.numTeamsDimsValues.push_back( + fir::getBase(converter.genExprValue(upperBound, stmtCtx))); + } + } else { + // Legacy case: single element with optional lower and upper bounds + auto &lowerBound = std::get<std::optional<ExprTy>>(clause->v[0].t); + auto &upperBound = std::get<ExprTy>(clause->v[0].t); + if (lowerBound) + result.numTeamsLower = + fir::getBase(converter.genExprValue(*lowerBound, stmtCtx)); + result.numTeamsUpper = + fir::getBase(converter.genExprValue(upperBound, stmtCtx)); + } return true; } return false; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index a2716fb22a75c..a2a952c2c2ba8 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -1297,10 +1297,28 @@ NumTasks make(const parser::OmpClause::NumTasks &inp, NumTeams make(const parser::OmpClause::NumTeams &inp, semantics::SemanticsContext &semaCtx) { // inp.v -> parser::OmpNumTeamsClause - auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t); - assert(!t1.empty()); - List<NumTeams::Range> v{{{/*LowerBound=*/std::nullopt, - /*UpperBound=*/makeExpr(t1.front(), semaCtx)}}}; + auto &mods = semantics::OmpGetModifiers(inp.v); + auto *dims = semantics::OmpGetUniqueModifier<parser::OmpDimsModifier>(mods); + auto *lowerBound = semantics::OmpGetUniqueModifier<parser::OmpLowerBound>(mods); + auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t); + assert(!values.empty()); + + // With dims modifier: create Range for each value (all upper bounds) + // The dims modifier value is stored as the list size matching dims count. + // Without dims modifier: single Range with optional lower bound + if (dims) { + List<NumTeams::Range> v; + for (const auto &val : values) { + v.push_back(NumTeams::Range{{/*LowerBound=*/std::nullopt, + /*UpperBound=*/makeExpr(val, semaCtx)}}); + } + return NumTeams{/*List=*/v}; + } + + // Without dims modifier: single element with optional lower bound + auto lb = maybeApplyToV(makeExprFn(semaCtx), lowerBound); + List<NumTeams::Range> v{{{/*LowerBound=*/lb, + /*UpperBound=*/makeExpr(values.front(), semaCtx)}}}; return NumTeams{/*List=*/v}; } diff --git a/flang/test/Lower/OpenMP/num-teams-dims.f90 b/flang/test/Lower/OpenMP/num-teams-dims.f90 new file mode 100644 index 0000000000000..fd5d5ba40c804 --- /dev/null +++ b/flang/test/Lower/OpenMP/num-teams-dims.f90 @@ -0,0 +1,52 @@ +! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s + +!=============================================================================== +! `num_teams` clause with dims modifier (OpenMP 6.1) +!=============================================================================== + +! CHECK-LABEL: func @_QPteams_numteams_dims2 +subroutine teams_numteams_dims2() + ! CHECK: omp.teams + ! CHECK-SAME: num_teams(dims(2): %{{.*}}, %{{.*}} : i32) + !$omp teams num_teams(dims(2): 10, 4) + call f1() + ! CHECK: omp.terminator + !$omp end teams +end subroutine teams_numteams_dims2 + +! CHECK-LABEL: func @_QPteams_numteams_dims3 +subroutine teams_numteams_dims3() + ! CHECK: omp.teams + ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32) + !$omp teams num_teams(dims(3): 8, 4, 2) + call f1() + ! CHECK: omp.terminator + !$omp end teams +end subroutine teams_numteams_dims3 + +! CHECK-LABEL: func @_QPteams_numteams_dims_var +subroutine teams_numteams_dims_var(a, b, c) + integer, intent(in) :: a, b, c + ! CHECK: omp.teams + ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32) + !$omp teams num_teams(dims(3): a, b, c) + call f1() + ! CHECK: omp.terminator + !$omp end teams +end subroutine teams_numteams_dims_var + +!=============================================================================== +! `num_teams` clause with lower bound (legacy, without dims) +!=============================================================================== + +! CHECK-LABEL: func @_QPteams_numteams_lower_upper +subroutine teams_numteams_lower_upper(lower, upper) + integer, intent(in) :: lower, upper + ! CHECK: omp.teams + ! CHECK-SAME: num_teams(%{{.*}} : i32 to %{{.*}} : i32) + !$omp teams num_teams(lower: upper) + call f1() + ! CHECK: omp.terminator + !$omp end teams +end subroutine teams_numteams_lower_upper + _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
