https://github.com/kevinsala updated https://github.com/llvm/llvm-project/pull/206412
>From 1d634ebc4175ea384acb654524911f1d583b0271 Mon Sep 17 00:00:00 2001 From: Kevin Sala <[email protected]> Date: Sun, 28 Jun 2026 14:14:36 -0700 Subject: [PATCH 1/2] [Clang][OpenMP] Add parsing for dims modifier in num_teams and thread_limit --- clang/include/clang/AST/OpenMPClause.h | 64 ++++- .../clang/Basic/DiagnosticSemaKinds.td | 15 +- clang/include/clang/Basic/OpenMPKinds.def | 8 + clang/include/clang/Basic/OpenMPKinds.h | 6 + clang/include/clang/Sema/SemaOpenMP.h | 32 ++- clang/lib/AST/OpenMPClause.cpp | 29 ++- clang/lib/AST/StmtProfile.cpp | 3 + clang/lib/Basic/OpenMPKinds.cpp | 21 +- clang/lib/Parse/ParseOpenMP.cpp | 84 +++++-- clang/lib/Sema/SemaOpenMP.cpp | 231 ++++++++++++------ clang/lib/Sema/TreeTransform.h | 34 ++- clang/lib/Serialization/ASTReader.cpp | 3 + clang/lib/Serialization/ASTWriter.cpp | 3 + clang/test/OpenMP/dims_modifier_ast_print.cpp | 40 +++ clang/test/OpenMP/dims_modifier_messages.cpp | 124 ++++++++++ clang/test/OpenMP/ompx_bare_messages.c | 2 +- ...et_teams_distribute_num_teams_messages.cpp | 16 +- ...ribute_parallel_for_num_teams_messages.cpp | 8 +- .../test/OpenMP/teams_num_teams_messages.cpp | 20 +- clang/tools/libclang/CIndex.cpp | 2 + 20 files changed, 596 insertions(+), 149 deletions(-) create mode 100644 clang/test/OpenMP/dims_modifier_ast_print.cpp create mode 100644 clang/test/OpenMP/dims_modifier_messages.cpp diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index 12959fe0c2ff0..bd58a41d4df1e 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -6982,6 +6982,13 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>, /// single expression 'n' as upper-bound and modifier expression 'm' as /// lower-bound. /// +/// \code +/// #pragma omp teams num_teams(dims(2): x, y) +/// \endcode +/// In this example directive '#pragma omp teams' has clause 'num_teams' with +/// the 'dims' modifier specifying two dimensions. 
The list specifies the number +/// of teams in each dimension. +/// /// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause /// can accept up to three expressions. /// @@ -7061,6 +7068,13 @@ class OMPNumTeamsClause final /// Set the expression of the modifier. void setModifierExpr(Expr *E) { *varlist_end() = E; } + /// Get the expression of the modifier if it is the dims modifier. + const Expr *getDimsModifierExpr() const { + if (Modifier == OMPC_NUMTEAMS_dims) + return getModifierExpr(); + return nullptr; + } + /// Get the location of the modifier. SourceLocation getModifierLoc() const { return ModifierLoc; } @@ -7097,6 +7111,13 @@ class OMPNumTeamsClause final /// In this example directive '#pragma omp teams' has clause 'thread_limit' /// with single expression 'n'. /// +/// \code +/// #pragma omp teams thread_limit(dims(2): x, y) +/// \endcode +/// In this example directive '#pragma omp teams' has clause 'thread_limit' with +/// the 'dims' modifier specifying two dimensions. The list specifies the limit +/// on the number of threads in each dimension. +/// /// When 'ompx_bare' clause exists on a 'target' directive, 'thread_limit' /// clause can accept up to three expressions. /// @@ -7110,6 +7131,12 @@ class OMPThreadLimitClause final friend OMPVarListClause; friend TrailingObjects; + /// Modifier that was specified. + OpenMPThreadLimitClauseModifier Modifier = OMPC_THREADLIMIT_unknown; + + /// Location of the modifier. + SourceLocation ModifierLoc; + OMPThreadLimitClause(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N) @@ -7131,11 +7158,16 @@ class OMPThreadLimitClause final /// \param LParenLoc Location of '('. /// \param EndLoc Ending location of the clause. /// \param VL List of references to the variables. + /// \param Modifier The modifier specified in the clause. + /// \param ModifierExpr The expression of the modifier. + /// \param ModifierLoc Location of the modifier. 
/// \param PreInit static OMPThreadLimitClause * Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc, SourceLocation LParenLoc, - SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit); + SourceLocation EndLoc, ArrayRef<Expr *> VL, + OpenMPThreadLimitClauseModifier Modifier, Expr *ModifierExpr, + SourceLocation ModifierLoc, Stmt *PreInit); /// Creates an empty clause with \a N variables. /// @@ -7151,9 +7183,37 @@ class OMPThreadLimitClause final return const_cast<OMPThreadLimitClause *>(this)->getThreadLimit(); } + /// Get the modifier. + OpenMPThreadLimitClauseModifier getModifier() const { return Modifier; } + + /// Set the modifier. + void setModifier(OpenMPThreadLimitClauseModifier M) { Modifier = M; } + + /// Get the expression of the modifier. + const Expr *getModifierExpr() const { return *varlist_end(); } + + /// Get the expression of the modifier. + Expr *getModifierExpr() { return *varlist_end(); } + + /// Set the expression of the modifier. + void setModifierExpr(Expr *E) { *varlist_end() = E; } + + /// Get the expression of the modifier if it is the dims modifier. + const Expr *getDimsModifierExpr() const { + if (Modifier == OMPC_THREADLIMIT_dims) + return getModifierExpr(); + return nullptr; + } + + /// Get the location of the modifier. + SourceLocation getModifierLoc() const { return ModifierLoc; } + + /// Set the location of the modifier. 
+ void setModifierLoc(SourceLocation Loc) { ModifierLoc = Loc; } + child_range children() { return child_range(reinterpret_cast<Stmt **>(varlist_begin()), - reinterpret_cast<Stmt **>(varlist_end())); + reinterpret_cast<Stmt **>(varlist_end()) + 1); } const_child_range children() const { diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 7360c9bbab60a..22dcaed45e598 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -12792,9 +12792,11 @@ def warn_omp_unterminated_declare_target : Warning< "expected '#pragma omp end declare target' at end of file to match '#pragma omp %0'">, InGroup<SourceUsesOpenMP>; def err_ompx_bare_no_grid : Error< - "'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses">; -def err_omp_multi_expr_not_allowed: Error<"only one expression allowed in '%0' clause">; -def err_ompx_more_than_three_expr_not_allowed: Error<"at most three expressions are allowed in '%0' clause in 'target teams ompx_bare' construct">; + "'ompx_bare' clause requires explicit grid size via 'num_teams' and 'thread_limit' clauses">; +def err_ompx_bare_no_dims : Error< + "'ompx_bare' clause cannot be specified with 'dims' modifier in 'num_teams' and 'thread_limit' clauses">; +def err_omp_multi_expr_not_allowed : Error<"only one expression allowed in '%0' clause">; +def err_omp_max_three_exprs: Error<"maximum three expressions are supported in '%0' clause">; def err_omp_transparent_invalid_value : Error<"invalid value for transparent clause," " expected one of: omp_not_impex, omp_import, omp_export, omp_impex">; def err_omp_transparent_invalid_type : Error< @@ -12803,6 +12805,12 @@ def err_omp_num_teams_lower_bound_larger : Error<"lower bound is greater than upper bound in 'num_teams' clause">; def err_omp_modifier_requires_version : Error< "'%0' modifier in '%1' clause requires OpenMP %2 or later">; +def 
err_omp_unexpected_num_exprs + : Error<"unexpected number of expressions in '%0' clause">; +def err_omp_max_supported_dims + : Error<"at most three dimensions are supported in '%0' clause">; +def err_omp_incompatible_modifiers + : Error<"'%0' modifier cannot be specified with '%1' modifier in '%2' clause">; } // end of OpenMP category let CategoryName = "Related Result Type Issue" in { @@ -12829,7 +12837,6 @@ def note_related_result_type_explicit : Note< "%select{| and is expected to return an instance of its class type}0">; def err_invalid_type_for_program_scope_var : Error< "the %0 type cannot be used to declare a program scope variable">; - } let CategoryName = "Modules Issue" in { diff --git a/clang/include/clang/Basic/OpenMPKinds.def b/clang/include/clang/Basic/OpenMPKinds.def index 079ff4a583f9f..5c396a5d2d4b3 100644 --- a/clang/include/clang/Basic/OpenMPKinds.def +++ b/clang/include/clang/Basic/OpenMPKinds.def @@ -101,6 +101,9 @@ #ifndef OPENMP_NUMTHREADS_MODIFIER #define OPENMP_NUMTHREADS_MODIFIER(Name) #endif +#ifndef OPENMP_THREADLIMIT_MODIFIER +#define OPENMP_THREADLIMIT_MODIFIER(Name) +#endif #ifndef OPENMP_DOACROSS_MODIFIER #define OPENMP_DOACROSS_MODIFIER(Name) #endif @@ -273,10 +276,14 @@ OPENMP_NUMTASKS_MODIFIER(strict) // Modifiers for the 'num_teams' clause. OPENMP_NUMTEAMS_MODIFIER(lower_bound) +OPENMP_NUMTEAMS_MODIFIER(dims) // Modifiers for the 'num_tasks' clause. OPENMP_NUMTHREADS_MODIFIER(strict) +// Modifiers for the 'thread_limit' clause. +OPENMP_THREADLIMIT_MODIFIER(dims) + // Modifiers for 'allocate' clause. 
OPENMP_ALLOCATE_MODIFIER(allocator) OPENMP_ALLOCATE_MODIFIER(align) @@ -301,6 +308,7 @@ OPENMP_USE_DEVICE_PTR_FALLBACK_MODIFIER(fb_preserve) #undef OPENMP_NUMTASKS_MODIFIER #undef OPENMP_NUMTEAMS_MODIFIER #undef OPENMP_NUMTHREADS_MODIFIER +#undef OPENMP_THREADLIMIT_MODIFIER #undef OPENMP_DYN_GROUPPRIVATE_MODIFIER #undef OPENMP_DYN_GROUPPRIVATE_FALLBACK_MODIFIER #undef OPENMP_GRAINSIZE_MODIFIER diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h index 3ee6cb83a431e..36c388668a455 100644 --- a/clang/include/clang/Basic/OpenMPKinds.h +++ b/clang/include/clang/Basic/OpenMPKinds.h @@ -274,6 +274,13 @@ enum OpenMPNumThreadsClauseModifier { OMPC_NUMTHREADS_unknown }; +/// OpenMP modifiers for 'thread_limit' clause. +enum OpenMPThreadLimitClauseModifier { +#define OPENMP_THREADLIMIT_MODIFIER(Name) OMPC_THREADLIMIT_##Name, +#include "clang/Basic/OpenMPKinds.def" + OMPC_THREADLIMIT_unknown +}; + /// OpenMP dependence types for 'doacross' clause. enum OpenMPDoacrossClauseModifier { #define OPENMP_DOACROSS_MODIFIER(Name) OMPC_DOACROSS_##Name, diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h index f689e227866a7..726ac8539c66b 100644 --- a/clang/include/clang/Sema/SemaOpenMP.h +++ b/clang/include/clang/Sema/SemaOpenMP.h @@ -1188,10 +1188,19 @@ class SemaOpenMP : public SemaBase { SourceLocation RLoc; CXXScopeSpec ReductionOrMapperIdScopeSpec; DeclarationNameInfo ReductionOrMapperId; - int ExtraModifier = - -1; ///< Additional modifier for linear, map, depend, - ///< lastprivate, use_device_ptr, or num_teams clause. - Expr *ExtraModifierExpr = nullptr; + /// Storage for up to two modifiers of one clause; the num_teams parser + /// records 'dims' in slot 0 and a following 'lower_bound' in slot 1. + SmallVector<int, 2> ExtraModifierArray = {-1, -1}; + SmallVector<Expr *, 2> ExtraModifierExprArray = {nullptr, nullptr}; + SmallVector<SourceLocation, 2> ExtraModifierLocArray = {SourceLocation(), + SourceLocation()}; + /// Additional modifier for linear, map, depend, lastprivate, + /// use_device_ptr, num_teams, or thread_limit clause. These references + /// alias slot 0 of the arrays above; a copy of this struct would still + /// refer to the source object's arrays, so pass it by reference.
+ int &ExtraModifier = ExtraModifierArray[0]; + Expr *&ExtraModifierExpr = ExtraModifierExprArray[0]; + SourceLocation &ExtraModifierLoc = ExtraModifierLocArray[0]; int OriginalSharingModifier = 0; // Default is shared int NeedDevicePtrModifier = 0; SourceLocation NeedDevicePtrModifierLoc; @@ -1203,7 +1208,6 @@ class SemaOpenMP : public SemaBase { MotionModifiers; SmallVector<SourceLocation, NumberOfOMPMotionModifiers> MotionModifiersLoc; bool IsMapTypeImplicit = false; - SourceLocation ExtraModifierLoc; SourceLocation OriginalSharingModifierLoc; SourceLocation OmpAllMemoryLoc; SourceLocation @@ -1345,13 +1349,15 @@ class SemaOpenMP : public SemaBase { /// Called on well-formed 'num_teams' clause. OMPClause *ActOnOpenMPNumTeamsClause( ArrayRef<Expr *> VarList, OpenMPNumTeamsClauseModifier Modifier, - Expr *ModifierExpr, SourceLocation ModifierLoc, SourceLocation StartLoc, + Expr *ModifierExpr, SourceLocation ModifierLoc, + OpenMPNumTeamsClauseModifier ModifierExtra, Expr *ModifierExtraExpr, + SourceLocation ModifierExtraLoc, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'thread_limit' clause. - OMPClause *ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList, - SourceLocation StartLoc, - SourceLocation LParenLoc, - SourceLocation EndLoc); + OMPClause *ActOnOpenMPThreadLimitClause( + ArrayRef<Expr *> VarList, OpenMPThreadLimitClauseModifier Modifier, + Expr *ModifierExpr, SourceLocation ModifierLoc, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'priority' clause. 
OMPClause *ActOnOpenMPPriorityClause(Expr *Priority, SourceLocation StartLoc, SourceLocation LParenLoc, @@ -1484,6 +1490,12 @@ class SemaOpenMP : public SemaBase { SourceLocation LLoc, SourceLocation RLoc, ArrayRef<OMPIteratorData> Data); + ExprResult ActOnOpenMPDimsModifier(OpenMPClauseKind Kind, int Modifier, + Expr *ModifierExpr, + SourceLocation ModifierLoc, + ArrayRef<Expr *> VarList, + SourceLocation VarListEndLoc); + void handleOMPAssumeAttr(Decl *D, const ParsedAttr &AL); /// Setter and getter functions for device_num. diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp index f3e548e898b39..d451255bf5845 100644 --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -1978,18 +1978,24 @@ OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C, OMPThreadLimitClause *OMPThreadLimitClause::Create( const ASTContext &C, OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc, - ArrayRef<Expr *> VL, Stmt *PreInit) { - void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size())); + ArrayRef<Expr *> VL, OpenMPThreadLimitClauseModifier Modifier, + Expr *ModifierExpr, SourceLocation ModifierLoc, Stmt *PreInit) { + // Reserve space for an extra modifier expression. + void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size() + 1)); OMPThreadLimitClause *Clause = new (Mem) OMPThreadLimitClause(C, StartLoc, LParenLoc, EndLoc, VL.size()); Clause->setVarRefs(VL); + Clause->setModifier(Modifier); + Clause->setModifierExpr(ModifierExpr); + Clause->setModifierLoc(ModifierLoc); Clause->setPreInitStmt(PreInit, CaptureRegion); return Clause; } OMPThreadLimitClause *OMPThreadLimitClause::CreateEmpty(const ASTContext &C, unsigned N) { - void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N)); + // Reserve space for an extra modifier expression. 
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N + 1)); return new (Mem) OMPThreadLimitClause(N); } @@ -2373,9 +2379,13 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) { void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) { if (!Node->varlist_empty()) { OS << "num_teams"; - if (const Expr *LowerBound = Node->getModifierExpr()) { + if (Node->getModifier() != OMPC_NUMTEAMS_unknown) { OS << "("; - LowerBound->printPretty(OS, nullptr, Policy, 0); + if (Node->getModifier() == OMPC_NUMTEAMS_dims) + OS << "dims("; + Node->getModifierExpr()->printPretty(OS, nullptr, Policy, 0); + if (Node->getModifier() == OMPC_NUMTEAMS_dims) + OS << ")"; VisitOMPClauseList(Node, ':'); } else { VisitOMPClauseList(Node, '('); @@ -2387,7 +2397,14 @@ void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) { void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) { if (!Node->varlist_empty()) { OS << "thread_limit"; - VisitOMPClauseList(Node, '('); + if (Node->getModifier() == OMPC_THREADLIMIT_dims) { + OS << "(dims("; + Node->getModifierExpr()->printPretty(OS, nullptr, Policy, 0); + OS << ")"; + VisitOMPClauseList(Node, ':'); + } else { + VisitOMPClauseList(Node, '('); + } OS << ")"; } } diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index a7e7006c98a1b..291c72385518e 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -920,6 +920,9 @@ void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) { } void OMPClauseProfiler::VisitOMPThreadLimitClause( const OMPThreadLimitClause *C) { + Profiler->VisitInteger(C->getModifier()); + if (const Expr *Modifier = C->getModifierExpr()) + Profiler->VisitStmt(Modifier); VisitOMPClauseList(C); VisitOMPClauseWithPreInit(C); } diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp index c590113578081..890ed0d29a7c6 100644 --- a/clang/lib/Basic/OpenMPKinds.cpp +++ 
b/clang/lib/Basic/OpenMPKinds.cpp @@ -224,6 +224,15 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str, return OMPC_NUMTEAMS_unknown; return Type; } + case OMPC_thread_limit: { + unsigned Type = llvm::StringSwitch<unsigned>(Str) +#define OPENMP_THREADLIMIT_MODIFIER(Name) .Case(#Name, OMPC_THREADLIMIT_##Name) +#include "clang/Basic/OpenMPKinds.def" + .Default(OMPC_THREADLIMIT_unknown); + if (LangOpts.OpenMP < 61) + return OMPC_THREADLIMIT_unknown; + return Type; + } case OMPC_allocate: return llvm::StringSwitch<OpenMPAllocateClauseModifier>(Str) #define OPENMP_ALLOCATE_MODIFIER(Name) .Case(#Name, OMPC_ALLOCATE_##Name) @@ -294,7 +303,6 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str, case OMPC_relaxed: case OMPC_threads: case OMPC_simd: - case OMPC_thread_limit: case OMPC_priority: case OMPC_nogroup: case OMPC_hint: @@ -606,6 +614,16 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind, #include "clang/Basic/OpenMPKinds.def" } llvm_unreachable("Invalid OpenMP 'num_teams' clause modifier"); + case OMPC_thread_limit: + switch (Type) { + case OMPC_THREADLIMIT_unknown: + return "unknown"; +#define OPENMP_THREADLIMIT_MODIFIER(Name) \ + case OMPC_THREADLIMIT_##Name: \ + return #Name; +#include "clang/Basic/OpenMPKinds.def" + } + llvm_unreachable("Invalid OpenMP 'thread_limit' clause modifier"); case OMPC_allocate: switch (Type) { case OMPC_ALLOCATE_unknown: @@ -683,7 +701,6 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind, case OMPC_relaxed: case OMPC_threads: case OMPC_simd: - case OMPC_thread_limit: case OMPC_priority: case OMPC_nogroup: case OMPC_hint: diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp index 62715da7966d0..2206d153badfd 100644 --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -5289,29 +5289,71 @@ bool Parser::ParseOpenMPVarList(OpenMPDirectiveKind DKind, Diag(Tok, 
diag::err_modifier_expected_colon) << "fallback"; } } - } else if (Kind == OMPC_num_teams) { - // Handle optional lower-bound modifier for num_teams clause. - Data.ExtraModifier = OMPC_NUMTEAMS_unknown; - TentativeParsingAction TPA(*this); - SourceLocation TLoc = Tok.getLocation(); - ExprResult FirstExpr = ParseAssignmentExpression(); - if (FirstExpr.isInvalid()) { - SkipUntil(tok::r_paren, tok::annot_pragma_openmp_end, StopBeforeMatch); - Data.RLoc = Tok.getLocation(); - if (!T.consumeClose()) - Data.RLoc = T.getCloseLocation(); - TPA.Commit(); - return true; + } else if (Kind == OMPC_num_teams || Kind == OMPC_thread_limit) { + int Mod = 0; + // Handle optional dims and lower-bound modifiers for num_teams clause, and + // the optional dims modifier for thread_limit clause. + Data.ExtraModifierArray[0] = Data.ExtraModifierArray[1] = + Kind == OMPC_num_teams ? static_cast<int>(OMPC_NUMTEAMS_unknown) + : static_cast<int>(OMPC_THREADLIMIT_unknown); + + // Lower-bound modifier is only accepted in num_teams. + bool CanParseLowerBoundModifier = (Kind == OMPC_num_teams); + if (!Tok.isAnnotation() && PP.getSpelling(Tok) == "dims" && + NextToken().is(tok::l_paren)) { + SourceLocation TLoc = Tok.getLocation(); + ConsumeToken(); + SourceLocation RLoc; + ExprResult ExprR = ParseOpenMPParensExpr(getOpenMPClauseName(Kind), RLoc); + if (ExprR.isUsable()) { + Data.ExtraModifierArray[Mod] = + Kind == OMPC_num_teams ? 
static_cast<int>(OMPC_NUMTEAMS_dims) + : static_cast<int>(OMPC_THREADLIMIT_dims); + Data.ExtraModifierExprArray[Mod] = ExprR.get(); + Data.ExtraModifierLocArray[Mod] = TLoc; + ++Mod; + } + + CanParseLowerBoundModifier &= Tok.is(tok::comma); + if (CanParseLowerBoundModifier || Tok.is(tok::colon)) { + ConsumeToken(); + } else { + Diag(Tok, diag::err_modifier_expected_colon) << "dims"; + SkipUntil(tok::r_paren, tok::annot_pragma_openmp_end, StopBeforeMatch); + Data.RLoc = Tok.getLocation(); + if (!T.consumeClose()) + Data.RLoc = T.getCloseLocation(); + return true; + } } - if (Tok.is(tok::colon)) { - ConsumeToken(); - Data.ExtraModifier = OMPC_NUMTEAMS_lower_bound; - Data.ExtraModifierExpr = FirstExpr.get(); - Data.ExtraModifierLoc = TLoc; - TPA.Commit(); - } else { - TPA.Revert(); + // The lower bound modifier must appear as the last modifier. + if (CanParseLowerBoundModifier) { + TentativeParsingAction TPA(*this); + SourceLocation TLoc = Tok.getLocation(); + ExprResult FirstExpr = ParseAssignmentExpression(); + if (FirstExpr.isInvalid()) { + SkipUntil(tok::r_paren, tok::annot_pragma_openmp_end, StopBeforeMatch); + Data.RLoc = Tok.getLocation(); + if (!T.consumeClose()) + Data.RLoc = T.getCloseLocation(); + TPA.Commit(); + return true; + } + + if (Tok.is(tok::colon)) { + // Correctly parsed the lower bound modifier. + ConsumeToken(); + Data.ExtraModifierArray[Mod] = OMPC_NUMTEAMS_lower_bound; + Data.ExtraModifierExprArray[Mod] = FirstExpr.get(); + Data.ExtraModifierLocArray[Mod] = TLoc; + TPA.Commit(); + } else { + // No ':' follows; revert and let the list parser consume the tokens.
+ TPA.Revert(); + if (Mod > 0) // 'dims(N),' is only valid before '<lower-bound>:'. + Diag(Tok, diag::err_modifier_expected_colon) << "dims"; + } } } diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 134b98d8e80cf..99925ecbe1acc 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -13621,23 +13621,74 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetUpdateDirective( Clauses, AStmt); } -/// This checks whether a \p ClauseType clause \p C has at most \p Max -/// expression. If not, a diag of number \p Diag will be emitted. -template <typename ClauseType> -static bool checkNumExprsInClause(SemaBase &SemaRef, - ArrayRef<OMPClause *> Clauses, - unsigned MaxNum, unsigned Diag) { - auto ClauseItr = llvm::find_if(Clauses, llvm::IsaPred<ClauseType>); - if (ClauseItr == Clauses.end()) - return true; - const auto *C = cast<ClauseType>(*ClauseItr); - auto VarList = C->getVarRefs(); - if (VarList.size() > MaxNum) { - SemaRef.Diag(VarList[MaxNum]->getBeginLoc(), Diag) - << getOpenMPClauseNameForDiag(C->getClauseKind()); +/// Check the number of expressions in a num_teams or thread_limit \p Clause. +/// The limit is the 'dims' value when that modifier is present, otherwise 3 +/// with 'ompx_bare' and 1 without; more than 3 is never accepted. Returns true +/// (after emitting a diagnostic) on violation, false otherwise, including when +/// \p Clause is null or its 'dims' value is still dependent. +template <typename ClauseT> +static bool checkClauseNumExprs(SemaBase &SemaRef, const ClauseT *Clause, + const OMPXBareClause *BareClause) { + if (!Clause) return false; + + uint64_t MaxExprs = BareClause ? 3 : 1; + + const Expr *DimsExpr = Clause->getDimsModifierExpr(); + if (DimsExpr) { + // Cannot verify the size yet.
+ if (DimsExpr->isInstantiationDependent()) + return false; + + MaxExprs = + DimsExpr->EvaluateKnownConstInt(SemaRef.getASTContext()).getExtValue(); } - return true; + + size_t NumVars = Clause->getVarRefs().size(); + if (NumVars > MaxExprs) { + SemaRef.Diag(Clause->getBeginLoc(), diag::err_omp_unexpected_num_exprs) + << getOpenMPClauseName(Clause->getClauseKind()); + return true; + } + if (NumVars > 3) { + SemaRef.Diag(Clause->getBeginLoc(), diag::err_omp_max_three_exprs) + << getOpenMPClauseName(Clause->getClauseKind()); + return true; + } + return false; +} + +static bool checkNumExprsInClauses(SemaBase &SemaRef, + ArrayRef<OMPClause *> Clauses) { + auto BareClauseIt = llvm::find_if(Clauses, llvm::IsaPred<OMPXBareClause>); + auto ThreadLimitIt = + llvm::find_if(Clauses, llvm::IsaPred<OMPThreadLimitClause>); + auto NumTeamsIt = llvm::find_if(Clauses, llvm::IsaPred<OMPNumTeamsClause>); + + const auto *BareClause = BareClauseIt != Clauses.end() + ? cast<OMPXBareClause>(*BareClauseIt) + : nullptr; + const auto *ThreadLimitClause = + ThreadLimitIt != Clauses.end() + ? cast<OMPThreadLimitClause>(*ThreadLimitIt) + : nullptr; + const auto *NumTeamsClause = NumTeamsIt != Clauses.end() + ? 
cast<OMPNumTeamsClause>(*NumTeamsIt) + : nullptr; + + if (BareClause) { + if (!NumTeamsClause || !ThreadLimitClause) { + SemaRef.Diag(BareClause->getBeginLoc(), diag::err_ompx_bare_no_grid); + return true; + } + if (ThreadLimitClause->getModifier() == OMPC_THREADLIMIT_dims || + NumTeamsClause->getModifier() == OMPC_NUMTEAMS_dims) { + SemaRef.Diag(BareClause->getBeginLoc(), diag::err_ompx_bare_no_dims); + return true; + } + } + return checkClauseNumExprs(SemaRef, ThreadLimitClause, BareClause) || + checkClauseNumExprs(SemaRef, NumTeamsClause, BareClause); } StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses, @@ -13647,10 +13693,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses, if (!AStmt) return StmtError(); - if (!checkNumExprsInClause<OMPNumTeamsClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || - !checkNumExprsInClause<OMPThreadLimitClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + if (checkNumExprsInClauses(*this, Clauses)) return StmtError(); // Report affected OpenMP target offloading behavior when in HIP lang-mode. @@ -14414,30 +14457,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective( setBranchProtectedScope(SemaRef, OMPD_target_teams, AStmt); - const OMPClause *BareClause = nullptr; - bool HasThreadLimitAndNumTeamsClause = hasClauses(Clauses, OMPC_num_teams) && - hasClauses(Clauses, OMPC_thread_limit); - bool HasBareClause = llvm::any_of(Clauses, [&](const OMPClause *C) { - BareClause = C; - return C->getClauseKind() == OMPC_ompx_bare; - }); - - if (HasBareClause && !HasThreadLimitAndNumTeamsClause) { - Diag(BareClause->getBeginLoc(), diag::err_ompx_bare_no_grid); + if (checkNumExprsInClauses(*this, Clauses)) return StmtError(); - } - unsigned ClauseMaxNumExprs = HasBareClause ? 3 : 1; - unsigned DiagNo = HasBareClause - ? 
diag::err_ompx_more_than_three_expr_not_allowed - : diag::err_omp_multi_expr_not_allowed; - - if (!checkNumExprsInClause<OMPNumTeamsClause>(*this, Clauses, - ClauseMaxNumExprs, DiagNo) || - !checkNumExprsInClause<OMPThreadLimitClause>(*this, Clauses, - ClauseMaxNumExprs, DiagNo)) { - return StmtError(); - } return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc, Clauses, AStmt); } @@ -14448,10 +14470,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeDirective( if (!AStmt) return StmtError(); - if (!checkNumExprsInClause<OMPNumTeamsClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || - !checkNumExprsInClause<OMPThreadLimitClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + if (checkNumExprsInClauses(*this, Clauses)) return StmtError(); CapturedStmt *CS = @@ -14480,10 +14499,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForDirective( if (!AStmt) return StmtError(); - if (!checkNumExprsInClause<OMPNumTeamsClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || - !checkNumExprsInClause<OMPThreadLimitClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + if (checkNumExprsInClauses(*this, Clauses)) return StmtError(); CapturedStmt *CS = setBranchProtectedScope( @@ -14513,10 +14529,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective( if (!AStmt) return StmtError(); - if (!checkNumExprsInClause<OMPNumTeamsClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || - !checkNumExprsInClause<OMPThreadLimitClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + if (checkNumExprsInClauses(*this, Clauses)) return StmtError(); CapturedStmt *CS = setBranchProtectedScope( @@ -14549,10 +14562,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective( if (!AStmt) return StmtError(); - if 
(!checkNumExprsInClause<OMPNumTeamsClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || - !checkNumExprsInClause<OMPThreadLimitClause>( - *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + if (checkNumExprsInClauses(*this, Clauses)) return StmtError(); CapturedStmt *CS = setBranchProtectedScope( @@ -17395,6 +17405,7 @@ ExprResult SemaOpenMP::VerifyPositiveIntegerConstantInClause( DSAStack->setAssociatedLoops(Result.getExtValue()); else if (CKind == OMPC_ordered) DSAStack->setAssociatedLoops(Result.getExtValue()); + return ICE; } @@ -19332,11 +19343,22 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, assert(0 <= ExtraModifier && ExtraModifier <= OMPC_NUMTEAMS_unknown && "Unexpected num_teams modifier."); + assert(0 <= Data.ExtraModifierArray[1] && + Data.ExtraModifierArray[1] <= OMPC_NUMTEAMS_unknown && + "Unexpected num_teams modifier."); Res = ActOnOpenMPNumTeamsClause( - VarList, static_cast<OpenMPNumTeamsClauseModifier>(ExtraModifier), - ExtraModifierExpr, ExtraModifierLoc, StartLoc, LParenLoc, EndLoc); + VarList, + static_cast<OpenMPNumTeamsClauseModifier>(Data.ExtraModifierArray[0]), + Data.ExtraModifierExprArray[0], Data.ExtraModifierLocArray[0], + static_cast<OpenMPNumTeamsClauseModifier>(Data.ExtraModifierArray[1]), + Data.ExtraModifierExprArray[1], Data.ExtraModifierLocArray[1], StartLoc, + LParenLoc, EndLoc); break; case OMPC_thread_limit: - Res = ActOnOpenMPThreadLimitClause(VarList, StartLoc, LParenLoc, EndLoc); + assert(0 <= ExtraModifier && ExtraModifier <= OMPC_THREADLIMIT_unknown && + "Unexpected thread_limit modifier."); + Res = ActOnOpenMPThreadLimitClause( + VarList, static_cast<OpenMPThreadLimitClauseModifier>(ExtraModifier), + ExtraModifierExpr, ExtraModifierLoc, StartLoc, LParenLoc, EndLoc); break; case OMPC_if: case OMPC_depobj: @@ -24401,9 +24423,52 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const { return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl(); } +ExprResult SemaOpenMP::ActOnOpenMPDimsModifier(OpenMPClauseKind ClauseKind, + int Modifier, Expr *ModifierExpr, +
SourceLocation ModifierLoc, + ArrayRef<Expr *> VarList, + SourceLocation VarListEndLoc) { + assert(ModifierExpr && "Unexpected modifier expression."); + + if (getLangOpts().OpenMP < 61) { + Diag(ModifierLoc, diag::err_omp_modifier_requires_version) + << getOpenMPSimpleClauseTypeName(ClauseKind, Modifier) + << getOpenMPClauseName(ClauseKind) << "6.1"; + return ExprError(); + } + + ExprResult DimsRes = VerifyPositiveIntegerConstantInClause( + ModifierExpr, ClauseKind, /*StrictlyPositive=*/true, + /*SuppressExprDiags=*/false); + if (DimsRes.isInvalid()) + return ExprError(); + + ModifierExpr = DimsRes.get(); + if (ModifierExpr->isInstantiationDependent()) + return DimsRes; + + // getLimitedValue() saturates rather than asserting on values wider than + // 64 bits; such values are then rejected by the dimension limit below. + uint64_t NumDims = + ModifierExpr->EvaluateKnownConstInt(getASTContext()).getLimitedValue(); + if (NumDims > 3) { + Diag(ModifierLoc, diag::err_omp_max_supported_dims) + << getOpenMPClauseName(ClauseKind); + return ExprError(); + } + if (NumDims == VarList.size()) + return DimsRes; + + Diag(VarListEndLoc, diag::err_omp_unexpected_num_exprs) + << getOpenMPClauseName(ClauseKind); + return ExprError(); +} + OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause( ArrayRef<Expr *> VarList, OpenMPNumTeamsClauseModifier Modifier, - Expr *ModifierExpr, SourceLocation ModifierLoc, SourceLocation StartLoc, + Expr *ModifierExpr, SourceLocation ModifierLoc, + OpenMPNumTeamsClauseModifier ModifierExtra, Expr *, + SourceLocation ModifierExtraLoc, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) { if (VarList.empty()) return nullptr; @@ -24417,10 +24482,27 @@ OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause( } // OpenMP [teams Construct, Restrictions] - // The lower-bound expression in num_teams must evaluate to a positive integer - // value. - if (ModifierExpr) { - assert(Modifier == OMPC_NUMTEAMS_lower_bound && "Unexpected modifier."); + // The lower-bound modifier cannot be specified if the dims modifier is + // specified.
+ if (Modifier != OMPC_NUMTEAMS_unknown && + ModifierExtra != OMPC_NUMTEAMS_unknown) { + Diag(ModifierExtraLoc, diag::err_omp_incompatible_modifiers) + << getOpenMPSimpleClauseTypeName(llvm::omp::OMPC_num_teams, + ModifierExtra) + << getOpenMPSimpleClauseTypeName(llvm::omp::OMPC_num_teams, Modifier) + << getOpenMPClauseName(llvm::omp::OMPC_num_teams); + ModifierExtra = OMPC_NUMTEAMS_unknown; + ModifierExtraLoc = SourceLocation(); + } + + if (Modifier == OMPC_NUMTEAMS_dims) { + ExprResult Res = ActOnOpenMPDimsModifier( + OMPC_num_teams, Modifier, ModifierExpr, ModifierLoc, VarList, EndLoc); + if (Res.isInvalid()) + return nullptr; + ModifierExpr = Res.get(); + } else if (Modifier == OMPC_NUMTEAMS_lower_bound) { + assert(ModifierExpr && "Unexpected modifier expression."); if (getLangOpts().OpenMP < 51) { Diag(ModifierLoc, diag::err_omp_modifier_requires_version) @@ -24429,6 +24501,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause( return nullptr; } + // OpenMP [teams Construct, Restrictions] + // The lower-bound expression in num_teams must evaluate to a positive + // integer value. 
if (!isNonNegativeIntegerValue(ModifierExpr, SemaRef, OMPC_num_teams, /*StrictlyPositive=*/true)) return nullptr; @@ -24485,10 +24560,10 @@ OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause( ModifierExpr, ModifierLoc, PreInit); } -OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList, - SourceLocation StartLoc, - SourceLocation LParenLoc, - SourceLocation EndLoc) { +OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause( + ArrayRef<Expr *> VarList, OpenMPThreadLimitClauseModifier Modifier, + Expr *ModifierExpr, SourceLocation ModifierLoc, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) { if (VarList.empty()) return nullptr; @@ -24500,12 +24575,22 @@ OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList, return nullptr; } + if (Modifier == OMPC_THREADLIMIT_dims) { + ExprResult Res = + ActOnOpenMPDimsModifier(OMPC_thread_limit, Modifier, ModifierExpr, + ModifierLoc, VarList, EndLoc); + if (Res.isInvalid()) + return nullptr; + ModifierExpr = Res.get(); + } + OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective(); OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause( DKind, OMPC_thread_limit, getLangOpts().OpenMP); if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext()) return OMPThreadLimitClause::Create(getASTContext(), CaptureRegion, StartLoc, LParenLoc, EndLoc, VarList, + Modifier, ModifierExpr, ModifierLoc, /*PreInit=*/nullptr); llvm::MapVector<const Expr *, DeclRefExpr *> Captures; @@ -24516,9 +24601,15 @@ OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList, Vars.push_back(ValExpr); } + if (ModifierExpr) { + ModifierExpr = SemaRef.MakeFullExpr(ModifierExpr).get(); + ModifierExpr = tryBuildCapture(SemaRef, ModifierExpr, Captures).get(); + } + Stmt *PreInit = buildPreInits(getASTContext(), Captures); return OMPThreadLimitClause::Create(getASTContext(), CaptureRegion, StartLoc, - LParenLoc, EndLoc, Vars, PreInit); + LParenLoc, 
EndLoc, Vars, Modifier, + ModifierExpr, ModifierLoc, PreInit); } OMPClause *SemaOpenMP::ActOnOpenMPPriorityClause(Expr *Priority, diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 3c8fcbe582b43..92df2e622a37c 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2146,23 +2146,26 @@ class TreeTransform { /// Subclasses may override this routine to provide different behavior. OMPClause *RebuildOMPNumTeamsClause( ArrayRef<Expr *> VarList, OpenMPNumTeamsClauseModifier Modifier, - Expr *ModifierExpr, SourceLocation ModifierLoc, SourceLocation StartLoc, + Expr *ModifierExpr, SourceLocation ModifierLoc, + OpenMPNumTeamsClauseModifier ModifierExtra, Expr *ModifierExtraExpr, + SourceLocation ModifierExtraLoc, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) { return getSema().OpenMP().ActOnOpenMPNumTeamsClause( - VarList, Modifier, ModifierExpr, ModifierLoc, StartLoc, LParenLoc, - EndLoc); + VarList, Modifier, ModifierExpr, ModifierLoc, ModifierExtra, + ModifierExtraExpr, ModifierExtraLoc, StartLoc, LParenLoc, EndLoc); } /// Build a new OpenMP 'thread_limit' clause. /// /// By default, performs semantic analysis to build the new statement. /// Subclasses may override this routine to provide different behavior. - OMPClause *RebuildOMPThreadLimitClause(ArrayRef<Expr *> VarList, - SourceLocation StartLoc, - SourceLocation LParenLoc, - SourceLocation EndLoc) { - return getSema().OpenMP().ActOnOpenMPThreadLimitClause(VarList, StartLoc, - LParenLoc, EndLoc); + OMPClause *RebuildOMPThreadLimitClause( + ArrayRef<Expr *> VarList, OpenMPThreadLimitClauseModifier Modifier, + Expr *ModifierExpr, SourceLocation ModifierLoc, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) { + return getSema().OpenMP().ActOnOpenMPThreadLimitClause( + VarList, Modifier, ModifierExpr, ModifierLoc, StartLoc, LParenLoc, + EndLoc); } /// Build a new OpenMP 'priority' clause. 
@@ -11611,7 +11614,8 @@ TreeTransform<Derived>::TransformOMPNumTeamsClause(OMPNumTeamsClause *C) { } return getDerived().RebuildOMPNumTeamsClause( Vars, C->getModifier(), ModifierExpr, C->getModifierLoc(), - C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); + OMPC_NUMTEAMS_unknown, nullptr, SourceLocation(), C->getBeginLoc(), + C->getLParenLoc(), C->getEndLoc()); } template <typename Derived> @@ -11625,8 +11629,16 @@ TreeTransform<Derived>::TransformOMPThreadLimitClause(OMPThreadLimitClause *C) { return nullptr; Vars.push_back(EVar.get()); } + Expr *ModifierExpr = C->getModifierExpr(); + if (ModifierExpr) { + ExprResult EVar = getDerived().TransformExpr(cast<Expr>(ModifierExpr)); + if (EVar.isInvalid()) + return nullptr; + ModifierExpr = EVar.get(); + } return getDerived().RebuildOMPThreadLimitClause( - Vars, C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); + Vars, C->getModifier(), ModifierExpr, C->getModifierLoc(), + C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); } template <typename Derived> diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index b01b18fe3e0ec..3189bf3431c1e 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -12608,6 +12608,9 @@ void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) { } void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) { + C->setModifier(Record.readEnum<OpenMPThreadLimitClauseModifier>()); + C->setModifierLoc(Record.readSourceLocation()); + C->setModifierExpr(Record.readSubExpr()); VisitOMPClauseWithPreInit(C); C->setLParenLoc(Record.readSourceLocation()); unsigned NumVars = C->varlist_size(); diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index ecf935e3b3548..a305e92b6a602 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -8599,6 +8599,9 @@ void 
OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) { void OMPClauseWriter::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) { Record.push_back(C->varlist_size()); + Record.writeEnum(C->getModifier()); + Record.AddSourceLocation(C->getModifierLoc()); + Record.AddStmt(C->getModifierExpr()); VisitOMPClauseWithPreInit(C); Record.AddSourceLocation(C->getLParenLoc()); for (auto *VE : C->varlist()) diff --git a/clang/test/OpenMP/dims_modifier_ast_print.cpp b/clang/test/OpenMP/dims_modifier_ast_print.cpp new file mode 100644 index 0000000000000..d47e8353abcad --- /dev/null +++ b/clang/test/OpenMP/dims_modifier_ast_print.cpp @@ -0,0 +1,40 @@ +// RUN: %clang_cc1 -fopenmp -fopenmp-version=61 -ast-print %s | FileCheck %s +// expected-no-diagnostics + +void test_ast_print() { + int N = 10; + int M = 20; + constexpr int D = 3; + + // CHECK: #pragma omp target teams num_teams(dims(3):1,2,3) thread_limit(dims(2):10,20) + #pragma omp target teams num_teams(dims(3): 1, 2, 3) thread_limit(dims(2): 10, 20) + {} + + // CHECK: #pragma omp target teams num_teams(dims(D):1,2,3) thread_limit(dims(D):10,20,30) + #pragma omp target teams num_teams(dims(D): 1, 2, 3) thread_limit(dims(D): 10, 20, 30) + {} + + // CHECK: #pragma omp target teams num_teams(dims(2):N,N + 1) thread_limit(dims(1):M) + #pragma omp target teams num_teams(dims(2): N, N + 1) thread_limit(dims(1): M) + {} + + // CHECK: #pragma omp target teams num_teams(dims(1):N) thread_limit(dims(1):M) + #pragma omp target teams num_teams(dims(1): N) thread_limit(dims(1): M) + {} +} + +template <int D> +void template_test() { + int arr[3] = {1, 2, 3}; + // CHECK: #pragma omp target teams num_teams(dims(3):arr[0],arr[1],arr[2]) thread_limit(dims(1):3) + #pragma omp target teams num_teams(dims(3): arr[0], arr[1], arr[2]) thread_limit(dims(1): D) + {} + + // CHECK: #pragma omp target teams num_teams(dims(3):arr[0],arr[1],arr[2]) thread_limit(dims(2):3,arr[0]) + #pragma omp target teams num_teams(dims(D): arr[0], arr[1], 
arr[2]) thread_limit(dims(2): D, arr[0]) + {} +} + +void call_templates() { + template_test<3>(); +} diff --git a/clang/test/OpenMP/dims_modifier_messages.cpp b/clang/test/OpenMP/dims_modifier_messages.cpp new file mode 100644 index 0000000000000..680bb1d6d8d13 --- /dev/null +++ b/clang/test/OpenMP/dims_modifier_messages.cpp @@ -0,0 +1,124 @@ +// RUN: %clang_cc1 -verify -fopenmp -fopenmp-version=61 -triple x86_64-unknown-unknown %s +// RUN: %clang_cc1 -verify -fopenmp-simd -fopenmp-version=61 -triple x86_64-unknown-unknown %s +// RUN: %clang_cc1 -verify -fopenmp -fopenmp-version=61 -triple x86_64-unknown-unknown -fopenmp-targets=nvptx64 %s +// RUN: %clang_cc1 -verify -fopenmp -fopenmp-version=61 -triple x86_64-unknown-unknown -fopenmp-targets=amdgcn-amd-amdhsa %s +// RUN: %clang_cc1 -verify -fopenmp -fopenmp-version=52 -triple x86_64-unknown-unknown -DVERSION52 %s + +void foo() { +} + +#ifndef VERSION52 +void bar(int N) { // expected-note {{declared here}} + // 1. Invalid syntax of the dims modifier. 
+ +#pragma omp target teams num_teams(dims 2: 4) // expected-error {{use of undeclared identifier 'dims'}} + foo(); + +#pragma omp target thread_limit(dim(2) 4, 5) + // expected-error@-1 {{use of undeclared identifier 'dim'}} + // expected-error@-2 {{expected ',' or ')' in 'thread_limit' clause}} + foo(); + +#pragma omp target thread_limit(dims((2): 4, 5) + // expected-error@-1 {{expected ')'}} + // expected-error@-2 {{expected ')'}} + // expected-note@-3 {{to match this '('}} + // expected-note@-4 {{to match this '('}} + // expected-error@-5 {{missing ':' after thread_limit modifier}} + foo(); + +#pragma omp target thread_limit(dims(2)): 4, 5) + // expected-error@-1 {{missing ':' after thread_limit modifier}} + // expected-warning@-2 {{extra tokens at the end of '#pragma omp target' are ignored}} + foo(); + +#pragma omp target thread_limit(dims(2) 4, 5) // expected-error {{missing ':' after thread_limit modifier}} + foo(); + +#pragma omp target teams distribute num_teams(dims(): 4) // expected-error {{expected expression}} + for (int i = 0; i < 10; ++i) {} + + // 2. Mismatching number of expressions. + +#pragma omp target teams num_teams(dims(2): 4) // expected-error {{unexpected number of expressions in 'num_teams' clause}} + foo(); + +#pragma omp target thread_limit(dims(1): 4, 5) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} + foo(); + +#pragma omp target teams distribute num_teams(dims(3): 4, 5) // expected-error {{unexpected number of expressions in 'num_teams' clause}} + for (int i = 0; i < 10; ++i) {} + + // 3. Exceeding three dimensions. 
+ +#pragma omp target teams num_teams(dims(4): 1, 2, 3, 4) // expected-error {{maximum three expressions are supported in 'num_teams' clause}} + foo(); + +#pragma omp target thread_limit(dims(2): 1, 2, 3, 4) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} + foo(); + +#pragma omp target teams distribute thread_limit(dims(4): 1, 2, 3, 4) // expected-error {{maximum three expressions are supported in 'thread_limit' clause}} + for (int i = 0; i < 10; ++i) {} + + // 4. Invalid use of dims when ompx_bare is present. + +#pragma omp target teams ompx_bare num_teams(dims(2): 1, 2) thread_limit(1, 2, 3) // expected-error {{'ompx_bare' clause cannot be specified with 'dims' modifier in 'num_teams' and 'thread_limit' clauses}} + foo(); + +#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(dims(3): 1, 2, 3) // expected-error {{'ompx_bare' clause cannot be specified with 'dims' modifier in 'num_teams' and 'thread_limit' clauses}} + foo(); + + // 5. Number of dimensions in dims is invalid. + +#pragma omp target teams num_teams(dims(N): 1, 2) + // expected-error@-1 {{expression is not an integral constant expression}} + // expected-note@-2 {{function parameter 'N' with unknown value cannot be used in a constant expression}} + foo(); + +#pragma omp target thread_limit(dims(2.5): 1, 2) // expected-error {{integral constant expression must have integral or unscoped enumeration type, not 'double'}} + foo(); + +#pragma omp target teams distribute num_teams(dims(0): 4) // expected-error {{argument to 'num_teams' clause must be a strictly positive integer value}} + for (int i = 0; i < 10; ++i) {} + +#pragma omp target teams thread_limit(dims(-1): 4) // expected-error {{argument to 'thread_limit' clause must be a strictly positive integer value}} + foo(); +} + +template <int D> +void template_test() { + // 7. Mismatching number of expressions with template arguments. 
+ +#pragma omp target teams num_teams(dims(D): 4) + // expected-error@-1 {{unexpected number of expressions in 'num_teams' clause}} + // expected-error@-2 {{argument to 'num_teams' clause must be a strictly positive integer value}} + foo(); + +#pragma omp target thread_limit(dims(D): 4, 5) + // expected-error@-1 {{unexpected number of expressions in 'thread_limit' clause}} + // expected-error@-2 {{argument to 'thread_limit' clause must be a strictly positive integer value}} + foo(); + +#pragma omp target teams distribute num_teams(dims(D): 4, 5) + // expected-error@-1 {{unexpected number of expressions in 'num_teams' clause}} + // expected-error@-2 {{argument to 'num_teams' clause must be a strictly positive integer value}} + for (int i = 0; i < 10; ++i) {} +} + +void call_templates() { + template_test<3>(); // expected-note {{in instantiation of function template specialization 'template_test<3>' requested here}} + template_test<0>(); // expected-note {{in instantiation of function template specialization 'template_test<0>' requested here}} +} +#endif + +#ifdef VERSION52 +void version() { + // 6. Dims modifier requires OpenMP 6.1. 
+ +#pragma omp target teams num_teams(dims(1): 4) // expected-error {{'dims' modifier in 'num_teams' clause requires OpenMP 6.1 or later}} + foo(); + +#pragma omp target thread_limit(dims(1): 4) // expected-error {{'dims' modifier in 'thread_limit' clause requires OpenMP 6.1 or later}} + foo(); +} +#endif diff --git a/clang/test/OpenMP/ompx_bare_messages.c b/clang/test/OpenMP/ompx_bare_messages.c index 19ceee5625fee..0fd931954c475 100644 --- a/clang/test/OpenMP/ompx_bare_messages.c +++ b/clang/test/OpenMP/ompx_bare_messages.c @@ -19,6 +19,6 @@ void bar() { #pragma omp teams ompx_bare // expected-error {{unexpected OpenMP clause 'ompx_bare' in directive '#pragma omp teams'}} expected-note {{OpenMP extension clause 'ompx_bare' only allowed with '#pragma omp target teams'}} foo(); -#pragma omp target teams ompx_bare // expected-error {{'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses}} +#pragma omp target teams ompx_bare // expected-error {{'ompx_bare' clause requires explicit grid size via 'num_teams' and 'thread_limit' clauses}} foo(); } diff --git a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp index 8bf388f0b5da9..e353fc81d4eca 100644 --- a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp +++ b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp @@ -44,16 +44,16 @@ T tmain(T argc) { #pragma omp target teams distribute num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} +#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute thread_limit(1, 2, 3) // 
expected-error {{only one expression allowed in 'thread_limit' clause}} +#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}} +#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{unexpected number of expressions in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{at most three expressions are allowed in 'thread_limit' clause in 'target teams ompx_bare' construct}} +#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} for (int i=0; i<100; i++) foo(); return 0; @@ -97,16 +97,16 @@ int main(int argc, char **argv) { #pragma omp target teams distribute num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} +#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} +#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' 
clause in 'target teams ompx_bare' construct}} +#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{unexpected number of expressions in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{at most three expressions are allowed in 'thread_limit' clause in 'target teams ompx_bare' construct}} +#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} for (int i=0; i<100; i++) foo(); return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} diff --git a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp index 092e0137d250d..fb9f4a7def289 100644 --- a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp +++ b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp @@ -43,9 +43,9 @@ T tmain(T argc) { for (int i=0; i<100; i++) foo(); #pragma omp target teams distribute parallel for num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} +#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} +#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 
'thread_limit' clause}} for (int i=0; i<100; i++) foo(); return 0; @@ -89,10 +89,10 @@ int main(int argc, char **argv) { #pragma omp target teams distribute parallel for num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} +#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} +#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} for (int i=0; i<100; i++) foo(); return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} diff --git a/clang/test/OpenMP/teams_num_teams_messages.cpp b/clang/test/OpenMP/teams_num_teams_messages.cpp index 683a772295cdf..59c954eb364ba 100644 --- a/clang/test/OpenMP/teams_num_teams_messages.cpp +++ b/clang/test/OpenMP/teams_num_teams_messages.cpp @@ -61,10 +61,10 @@ T tmain(T argc) { #pragma omp teams num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} foo(); #pragma omp target -#pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} +#pragma omp teams num_teams (1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} foo(); #pragma omp target -#pragma omp teams thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} +#pragma omp teams thread_limit(1, 2, 3) // expected-error {{unexpected 
number of expressions in 'thread_limit' clause}} foo(); return 0; @@ -121,11 +121,11 @@ int main(int argc, char **argv) { foo(); #pragma omp target -#pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} +#pragma omp teams num_teams (1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} foo(); #pragma omp target -#pragma omp teams thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} +#pragma omp teams thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} foo(); return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} @@ -135,28 +135,28 @@ int main(int argc, char **argv) { void test_invalid_syntax() { int a = 1, b = 2, c = 3; - // expected-error@+1 {{only one expression allowed in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} #pragma omp teams num_teams(a, b, c) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} #pragma omp teams num_teams(10:5) { } - // expected-error@+1 {{only one expression allowed in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} #pragma omp target teams num_teams(a, b, c) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} #pragma omp target teams num_teams(8:3) { } - // expected-error@+1 {{only one expression allowed in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} #pragma omp target teams distribute num_teams(a, b, c) for (int i = 0; i < 100; ++i) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} #pragma omp target teams distribute num_teams(15:7) for (int i = 0; i < 100; ++i) { } - // expected-error@+1 
{{only one expression allowed in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} #pragma omp target teams distribute parallel for num_teams(a, b, c) for (int i = 0; i < 100; ++i) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} @@ -164,14 +164,14 @@ void test_invalid_syntax() { for (int i = 0; i < 100; ++i) { } // Test target teams distribute parallel for simd directive - // expected-error@+1 {{only one expression allowed in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} #pragma omp target teams distribute parallel for simd num_teams(a, b, c) for (int i = 0; i < 100; ++i) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} #pragma omp target teams distribute parallel for simd num_teams(20:6) for (int i = 0; i < 100; ++i) { } - // expected-error@+1 {{only one expression allowed in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} #pragma omp target teams distribute simd num_teams(a, b, c) for (int i = 0; i < 100; ++i) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index 0dbf315166696..af9dc02656597 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2533,6 +2533,8 @@ void OMPClauseEnqueue::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) { void OMPClauseEnqueue::VisitOMPThreadLimitClause( const OMPThreadLimitClause *C) { + if (const Expr *Modifier = C->getModifierExpr()) + Visitor->AddStmt(Modifier); VisitOMPClauseList(C); VisitOMPClauseWithPreInit(C); } >From bcf2f5caf622d73fe998c5e31cf4bf416dfd9655 Mon Sep 17 00:00:00 2001 From: Kevin Sala <[email protected]> Date: Wed, 1 Jul 2026 18:10:23 -0700 Subject: [PATCH 2/2] Fix review comments --- 
clang/include/clang/AST/OpenMPClause.h | 38 ++++++++++--------- .../clang/Basic/DiagnosticSemaKinds.td | 6 +-- clang/include/clang/Sema/SemaOpenMP.h | 9 +++-- clang/lib/Parse/ParseOpenMP.cpp | 2 +- clang/lib/Sema/SemaOpenMP.cpp | 27 +++++++++---- clang/test/OpenMP/dims_modifier_messages.cpp | 5 +++ ...et_teams_distribute_num_teams_messages.cpp | 16 ++++---- ...ribute_parallel_for_num_teams_messages.cpp | 8 ++-- .../test/OpenMP/teams_num_teams_messages.cpp | 20 +++++----- 9 files changed, 75 insertions(+), 56 deletions(-) diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index bd58a41d4df1e..641c03bf8ff5c 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -6999,6 +6999,7 @@ class OMPNumTeamsClause final : public OMPVarListClause<OMPNumTeamsClause>, public OMPClauseWithPreInit, private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> { + friend class OMPClauseReader; friend OMPVarListClause; friend TrailingObjects; @@ -7020,6 +7021,15 @@ class OMPNumTeamsClause final SourceLocation(), SourceLocation(), N), OMPClauseWithPreInit(this) {} + /// Set the modifier. + void setModifier(OpenMPNumTeamsClauseModifier M) { Modifier = M; } + + /// Set the expression of the modifier. + void setModifierExpr(Expr *E) { *varlist_end() = E; } + + /// Set the location of the modifier. + void setModifierLoc(SourceLocation Loc) { ModifierLoc = Loc; } + public: /// Creates clause with a list of variables \a VL. /// @@ -7056,18 +7066,12 @@ class OMPNumTeamsClause final /// Get the modifier. OpenMPNumTeamsClauseModifier getModifier() const { return Modifier; } - /// Set the modifier. - void setModifier(OpenMPNumTeamsClauseModifier M) { Modifier = M; } - /// Get the expression of the modifier. const Expr *getModifierExpr() const { return *varlist_end(); } /// Get the expression of the modifier. Expr *getModifierExpr() { return *varlist_end(); } - /// Set the expression of the modifier. 
- void setModifierExpr(Expr *E) { *varlist_end() = E; } - /// Get the expression of the modifier if it is the dims modifier. const Expr *getDimsModifierExpr() const { if (Modifier == OMPC_NUMTEAMS_dims) @@ -7078,9 +7082,6 @@ class OMPNumTeamsClause final /// Get the location of the modifier. SourceLocation getModifierLoc() const { return ModifierLoc; } - /// Set the location of the modifier. - void setModifierLoc(SourceLocation Loc) { ModifierLoc = Loc; } - child_range children() { return child_range(reinterpret_cast<Stmt **>(varlist_begin()), reinterpret_cast<Stmt **>(varlist_end()) + 1); @@ -7128,6 +7129,7 @@ class OMPThreadLimitClause final : public OMPVarListClause<OMPThreadLimitClause>, public OMPClauseWithPreInit, private llvm::TrailingObjects<OMPThreadLimitClause, Expr *> { + friend class OMPClauseReader; friend OMPVarListClause; friend TrailingObjects; @@ -7150,6 +7152,15 @@ class OMPThreadLimitClause final SourceLocation(), SourceLocation(), N), OMPClauseWithPreInit(this) {} + /// Set the modifier. + void setModifier(OpenMPThreadLimitClauseModifier M) { Modifier = M; } + + /// Set the location of the modifier. + void setModifierLoc(SourceLocation Loc) { ModifierLoc = Loc; } + + /// Set the expression of the modifier. + void setModifierExpr(Expr *E) { *varlist_end() = E; } + public: /// Creates clause with a list of variables \a VL. /// @@ -7186,18 +7197,12 @@ class OMPThreadLimitClause final /// Get the modifier. OpenMPThreadLimitClauseModifier getModifier() const { return Modifier; } - /// Set the modifier. - void setModifier(OpenMPThreadLimitClauseModifier M) { Modifier = M; } - /// Get the expression of the modifier. const Expr *getModifierExpr() const { return *varlist_end(); } /// Get the expression of the modifier. Expr *getModifierExpr() { return *varlist_end(); } - /// Set the expression of the modifier. - void setModifierExpr(Expr *E) { *varlist_end() = E; } - /// Get the expression of the modifier if it is the dims modifier. 
const Expr *getDimsModifierExpr() const { if (Modifier == OMPC_THREADLIMIT_dims) @@ -7208,9 +7213,6 @@ class OMPThreadLimitClause final /// Get the location of the modifier. SourceLocation getModifierLoc() const { return ModifierLoc; } - /// Set the location of the modifier. - void setModifierLoc(SourceLocation Loc) { ModifierLoc = Loc; } - child_range children() { return child_range(reinterpret_cast<Stmt **>(varlist_begin()), reinterpret_cast<Stmt **>(varlist_end()) + 1); diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 22dcaed45e598..c76f2f5ed6793 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -12793,9 +12793,9 @@ def warn_omp_unterminated_declare_target : Warning< InGroup<SourceUsesOpenMP>; def err_ompx_bare_no_grid : Error< "'ompx_bare' clause requires explicit grid size via 'num_teams' and 'thread_limit' clauses">; +def err_ompx_more_than_three_expr_not_allowed : Error<"at most three expressions are allowed in '%0' clause in 'target teams ompx_bare' construct">; def err_ompx_bare_no_dims : Error< "'ompx_bare' clause cannot be specified with 'dims' modifier in 'num_teams' and 'thread_limit' clauses">; -def err_omp_multi_expr_not_allowed : Error<"only one expression allowed in '%0' clause">; def err_omp_max_three_exprs: Error<"maximum three expressions are supported in '%0' clause">; def err_omp_transparent_invalid_value : Error<"invalid value for transparent clause," " expected one of: omp_not_impex, omp_import, omp_export, omp_impex">; @@ -12806,9 +12806,7 @@ def err_omp_num_teams_lower_bound_larger def err_omp_modifier_requires_version : Error< "'%0' modifier in '%1' clause requires OpenMP %2 or later">; def err_omp_unexpected_num_exprs - : Error<"unexpected number of expressions in '%0' clause">; -def err_omp_max_supported_dims - : Error<"at most three dimensions are supported in '%0' clause">; + : Error<"unexpected 
number of expressions in '%0' clause (expected %1, have %2)">; def err_omp_incompatible_modifiers : Error<"'%0' modifier cannot be specified with '%1' modifier in '%2' clause">; } // end of OpenMP category diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h index 726ac8539c66b..9ef388a744db8 100644 --- a/clang/include/clang/Sema/SemaOpenMP.h +++ b/clang/include/clang/Sema/SemaOpenMP.h @@ -31,6 +31,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Frontend/OpenMP/OMP.h.inc" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include <array> #include <optional> #include <string> #include <utility> @@ -1188,10 +1189,10 @@ class SemaOpenMP : public SemaBase { SourceLocation RLoc; CXXScopeSpec ReductionOrMapperIdScopeSpec; DeclarationNameInfo ReductionOrMapperId; - SmallVector<int, 2> ExtraModifierArray = {-1, -1}; - SmallVector<Expr *, 2> ExtraModifierExprArray = {nullptr, nullptr}; - SmallVector<SourceLocation, 2> ExtraModifierLocArray = {SourceLocation(), - SourceLocation()}; + std::array<int, 2> ExtraModifierArray = {-1, -1}; + std::array<Expr *, 2> ExtraModifierExprArray = {nullptr, nullptr}; + std::array<SourceLocation, 2> ExtraModifierLocArray = {SourceLocation(), + SourceLocation()}; /// Additional modifier for linear, map, depend, lastprivate, /// use_device_ptr, or num_teams clause. int &ExtraModifier = ExtraModifierArray[0]; diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp index 2206d153badfd..7ef1e4211a1ef 100644 --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -5299,7 +5299,7 @@ bool Parser::ParseOpenMPVarList(OpenMPDirectiveKind DKind, // Lower-bound modifier is only accepted in num_teams. 
bool CanParseLowerBoundModifier = (Kind == OMPC_num_teams); - if (!Tok.isAnnotation() && PP.getSpelling(Tok) == "dims" && + if (Tok.is(tok::identifier) && Tok.getIdentifierInfo()->isStr("dims") && NextToken().is(tok::l_paren)) { SourceLocation TLoc = Tok.getLocation(); ConsumeToken(); diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 99925ecbe1acc..caf79229739f0 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -13627,22 +13627,35 @@ static bool checkClauseNumExprs(SemaBase &SemaRef, const ClauseT *Clause, if (!Clause) return false; - uint64_t MaxExprs = BareClause ? 3 : 1; + const uint64_t NumVars = Clause->getVarRefs().size(); + if (BareClause) { + // The ompx_bare clause allows up to three expressions. + if (NumVars > 3) { + SemaRef.Diag(Clause->getBeginLoc(), + diag::err_ompx_more_than_three_expr_not_allowed) + << getOpenMPClauseName(Clause->getClauseKind()); + return true; + } + return false; + } + + // By default, only one expression is accepted. + uint64_t MaxExprs = 1; const Expr *DimsExpr = Clause->getDimsModifierExpr(); if (DimsExpr) { - // Cannot verify the size yet. + // Cannot verify the expected size yet. if (DimsExpr->isInstantiationDependent()) return false; + // The dims modifier determines the exact number of expressions.
MaxExprs = DimsExpr->EvaluateKnownConstInt(SemaRef.getASTContext()).getExtValue(); } - size_t NumVars = Clause->getVarRefs().size(); - if (NumVars > MaxExprs) { + if (NumVars != MaxExprs) { SemaRef.Diag(Clause->getBeginLoc(), diag::err_omp_unexpected_num_exprs) - << getOpenMPClauseName(Clause->getClauseKind()); + << getOpenMPClauseName(Clause->getClauseKind()) << MaxExprs << NumVars; return true; } if (NumVars > 3) { @@ -19352,7 +19365,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, break; case OMPC_thread_limit: assert(0 <= ExtraModifier && ExtraModifier <= OMPC_THREADLIMIT_unknown && - "Unexpected num_teams modifier."); + "Unexpected thread_limit modifier."); Res = ActOnOpenMPThreadLimitClause( VarList, static_cast<OpenMPThreadLimitClauseModifier>(ExtraModifier), ExtraModifierExpr, ExtraModifierLoc, StartLoc, LParenLoc, EndLoc); @@ -24450,7 +24463,7 @@ ExprResult SemaOpenMP::ActOnOpenMPDimsModifier(OpenMPClauseKind ClauseKind, return DimsRes; Diag(VarListEndLoc, diag::err_omp_unexpected_num_exprs) - << getOpenMPClauseName(ClauseKind); + << getOpenMPClauseName(ClauseKind) << NumDims << VarList.size(); return ExprError(); } diff --git a/clang/test/OpenMP/dims_modifier_messages.cpp b/clang/test/OpenMP/dims_modifier_messages.cpp index 680bb1d6d8d13..ce0dd7eb79c52 100644 --- a/clang/test/OpenMP/dims_modifier_messages.cpp +++ b/clang/test/OpenMP/dims_modifier_messages.cpp @@ -38,6 +38,11 @@ void bar(int N) { // expected-note {{declared here}} #pragma omp target teams distribute num_teams(dims(): 4) // expected-error {{expected expression}} for (int i = 0; i < 10; ++i) {} + // 3. Incompatible modifiers. + +#pragma omp target teams num_teams(dims(1),10:20) // expected-error {{'lower_bound' modifier cannot be specified with 'dims' modifier in 'num_teams' clause}} + foo(); + // 2. Mismatching number of expressions. 
#pragma omp target teams num_teams(dims(2): 4) // expected-error {{unexpected number of expressions in 'num_teams' clause}} diff --git a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp index e353fc81d4eca..0fc2f3ff54017 100644 --- a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp +++ b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp @@ -44,16 +44,16 @@ T tmain(T argc) { #pragma omp target teams distribute num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} +#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} +#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause (expected 1, have 3)}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{unexpected number of expressions in 'num_teams' clause}} +#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} +#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{at most three expressions are allowed in 
'thread_limit' clause in 'target teams ompx_bare' construct}} for (int i=0; i<100; i++) foo(); return 0; @@ -97,16 +97,16 @@ int main(int argc, char **argv) { #pragma omp target teams distribute num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} +#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} +#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause (expected 1, have 3)}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{unexpected number of expressions in 'num_teams' clause}} +#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} +#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{at most three expressions are allowed in 'thread_limit' clause in 'target teams ompx_bare' construct}} for (int i=0; i<100; i++) foo(); return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} diff --git a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp 
b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp index fb9f4a7def289..4d72a2c23bb28 100644 --- a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp +++ b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp @@ -43,9 +43,9 @@ T tmain(T argc) { for (int i=0; i<100; i++) foo(); #pragma omp target teams distribute parallel for num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} +#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} +#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause (expected 1, have 3)}} for (int i=0; i<100; i++) foo(); return 0; @@ -89,10 +89,10 @@ int main(int argc, char **argv) { #pragma omp target teams distribute parallel for num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} +#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} for (int i=0; i<100; i++) foo(); -#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 
'thread_limit' clause}} +#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause (expected 1, have 3)}} for (int i=0; i<100; i++) foo(); return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} diff --git a/clang/test/OpenMP/teams_num_teams_messages.cpp b/clang/test/OpenMP/teams_num_teams_messages.cpp index 59c954eb364ba..92c85ac3d710b 100644 --- a/clang/test/OpenMP/teams_num_teams_messages.cpp +++ b/clang/test/OpenMP/teams_num_teams_messages.cpp @@ -61,10 +61,10 @@ T tmain(T argc) { #pragma omp teams num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} foo(); #pragma omp target -#pragma omp teams num_teams (1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} +#pragma omp teams num_teams (1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} foo(); #pragma omp target -#pragma omp teams thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} +#pragma omp teams thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause (expected 1, have 3)}} foo(); return 0; @@ -121,11 +121,11 @@ int main(int argc, char **argv) { foo(); #pragma omp target -#pragma omp teams num_teams (1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause}} +#pragma omp teams num_teams (1, 2, 3) // expected-error {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} foo(); #pragma omp target -#pragma omp teams thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause}} +#pragma omp teams thread_limit(1, 2, 3) // expected-error {{unexpected number of expressions in 'thread_limit' clause (expected 1, have 3)}} 
foo(); return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} @@ -135,28 +135,28 @@ int main(int argc, char **argv) { void test_invalid_syntax() { int a = 1, b = 2, c = 3; - // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} #pragma omp teams num_teams(a, b, c) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} #pragma omp teams num_teams(10:5) { } - // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} #pragma omp target teams num_teams(a, b, c) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} #pragma omp target teams num_teams(8:3) { } - // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} #pragma omp target teams distribute num_teams(a, b, c) for (int i = 0; i < 100; ++i) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} #pragma omp target teams distribute num_teams(15:7) for (int i = 0; i < 100; ++i) { } - // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} #pragma omp target teams distribute parallel for num_teams(a, b, c) for (int i = 0; i < 100; ++i) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} @@ -164,14 +164,14 @@ void test_invalid_syntax() { for (int i = 0; i < 100; ++i) { } // Test target teams distribute parallel for simd directive - // expected-error@+1 {{unexpected number of expressions in 
'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} #pragma omp target teams distribute parallel for simd num_teams(a, b, c) for (int i = 0; i < 100; ++i) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} #pragma omp target teams distribute parallel for simd num_teams(20:6) for (int i = 0; i < 100; ++i) { } - // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause}} + // expected-error@+1 {{unexpected number of expressions in 'num_teams' clause (expected 1, have 3)}} #pragma omp target teams distribute simd num_teams(a, b, c) for (int i = 0; i < 100; ++i) { } // expected-error@+1 {{lower bound is greater than upper bound in 'num_teams' clause}} _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
