llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang-codegen Author: Amit Tiwari (amitamd7) <details> <summary>Changes</summary> Implement the loop-splitting `#pragma omp split` construct with the `counts` clause. This re-posts the change after PR [#<!-- -->183261](https://github.com/llvm/llvm-project/pull/183261) was reverted. --- Patch is 218.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/190397.diff 80 Files Affected: - (modified) clang/bindings/python/clang/cindex.py (+3) - (modified) clang/include/clang-c/Index.h (+4) - (modified) clang/include/clang/AST/OpenMPClause.h (+101) - (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+10) - (modified) clang/include/clang/AST/StmtOpenMP.h (+78) - (modified) clang/include/clang/ASTMatchers/ASTMatchers.h (+20) - (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2) - (modified) clang/include/clang/Basic/StmtNodes.td (+1) - (modified) clang/include/clang/Parse/Parser.h (+3) - (modified) clang/include/clang/Sema/SemaOpenMP.h (+12) - (modified) clang/include/clang/Serialization/ASTBitCodes.h (+1) - (modified) clang/lib/AST/OpenMPClause.cpp (+35) - (modified) clang/lib/AST/StmtOpenMP.cpp (+21) - (modified) clang/lib/AST/StmtPrinter.cpp (+5) - (modified) clang/lib/AST/StmtProfile.cpp (+10) - (modified) clang/lib/ASTMatchers/ASTMatchersInternal.cpp (+4) - (modified) clang/lib/ASTMatchers/Dynamic/Registry.cpp (+2) - (modified) clang/lib/Basic/OpenMPKinds.cpp (+4-1) - (modified) clang/lib/CodeGen/CGStmt.cpp (+3) - (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+8) - (modified) clang/lib/CodeGen/CodeGenFunction.h (+1) - (modified) clang/lib/Parse/ParseOpenMP.cpp (+59) - (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1) - (modified) clang/lib/Sema/SemaOpenMP.cpp (+271) - (modified) clang/lib/Sema/TreeTransform.h (+44) - (modified) clang/lib/Serialization/ASTReader.cpp (+15) - (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+11) - (modified) 
clang/lib/Serialization/ASTWriter.cpp (+11) - (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+5) - (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+1) - (added) clang/test/AST/ast-dump-openmp-split.c (+19) - (added) clang/test/Index/openmp-split.c (+11) - (added) clang/test/OpenMP/split_analyze.c (+10) - (added) clang/test/OpenMP/split_ast_print.cpp (+71) - (added) clang/test/OpenMP/split_codegen.cpp (+1986) - (added) clang/test/OpenMP/split_composition.cpp (+17) - (added) clang/test/OpenMP/split_compound_associated.cpp (+13) - (added) clang/test/OpenMP/split_counts_constexpr.cpp (+19) - (added) clang/test/OpenMP/split_counts_ice.c (+56) - (added) clang/test/OpenMP/split_counts_verify.c (+123) - (added) clang/test/OpenMP/split_diag_errors.c (+61) - (added) clang/test/OpenMP/split_distribute_inner_split.cpp (+14) - (added) clang/test/OpenMP/split_driver_smoke.c (+12) - (added) clang/test/OpenMP/split_iv_types.c (+24) - (added) clang/test/OpenMP/split_loop_styles.cpp (+14) - (added) clang/test/OpenMP/split_member_ctor.cpp (+20) - (added) clang/test/OpenMP/split_messages.cpp (+108) - (added) clang/test/OpenMP/split_nested_outer_only.c (+12) - (added) clang/test/OpenMP/split_offload_codegen.cpp (+27) - (added) clang/test/OpenMP/split_omp_fill.c (+36) - (added) clang/test/OpenMP/split_openmp_version.cpp (+22) - (added) clang/test/OpenMP/split_opts_simd_debug.cpp (+30) - (added) clang/test/OpenMP/split_parallel_split.cpp (+15) - (added) clang/test/OpenMP/split_pch_codegen.cpp (+43) - (added) clang/test/OpenMP/split_range_for_diag.cpp (+25) - (added) clang/test/OpenMP/split_serialize_module.cpp (+24) - (added) clang/test/OpenMP/split_teams_nesting.cpp (+13) - (added) clang/test/OpenMP/split_template_nttp.cpp (+15) - (added) clang/test/OpenMP/split_templates.cpp (+30) - (added) clang/test/OpenMP/split_trip_volatile.c (+14) - (modified) clang/tools/libclang/CIndex.cpp (+7) - (modified) clang/tools/libclang/CXCursor.cpp (+3) - (modified) 
clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp (+62) - (modified) clang/unittests/ASTMatchers/ASTMatchersTest.h (+14) - (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+11-10) - (added) openmp/runtime/test/transform/split/fill_first.c (+23) - (added) openmp/runtime/test/transform/split/foreach.cpp (+24) - (added) openmp/runtime/test/transform/split/intfor.c (+26) - (added) openmp/runtime/test/transform/split/intfor_negstart.c (+27) - (added) openmp/runtime/test/transform/split/iterfor.cpp (+139) - (added) openmp/runtime/test/transform/split/leq_bound.c (+22) - (added) openmp/runtime/test/transform/split/lit.local.cfg (+5) - (added) openmp/runtime/test/transform/split/negative_incr.c (+22) - (added) openmp/runtime/test/transform/split/nonconstant_incr.c (+22) - (added) openmp/runtime/test/transform/split/parallel-split-intfor.c (+27) - (added) openmp/runtime/test/transform/split/single_fill.c (+23) - (added) openmp/runtime/test/transform/split/three_segments.c (+26) - (added) openmp/runtime/test/transform/split/trip_one.c (+32) - (added) openmp/runtime/test/transform/split/unsigned_iv.c (+24) - (added) openmp/runtime/test/transform/split/zero_first_segment.c (+21) ``````````diff diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py index b71f9ed2275e0..a90d48cf6d481 100644 --- a/clang/bindings/python/clang/cindex.py +++ b/clang/bindings/python/clang/cindex.py @@ -1453,6 +1453,9 @@ def is_unexposed(self): # OpenMP fuse directive. OMP_FUSE_DIRECTIVE = 311 + # OpenMP split directive. + OMP_SPLIT_DIRECTIVE = 312 + # OpenACC Compute Construct. OPEN_ACC_COMPUTE_DIRECTIVE = 320 diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h index dcf1f4f1b4258..119bd68ff9814 100644 --- a/clang/include/clang-c/Index.h +++ b/clang/include/clang-c/Index.h @@ -2166,6 +2166,10 @@ enum CXCursorKind { */ CXCursor_OMPFuseDirective = 311, + /** OpenMP split directive. 
+ */ + CXCursor_OMPSplitDirective = 312, + /** OpenACC Compute Construct. */ CXCursor_OpenACCComputeConstruct = 320, diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index af5d3f4698eda..ccf2c40bc5efa 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -39,6 +39,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/TrailingObjects.h" #include <cassert> +#include <climits> #include <cstddef> #include <iterator> #include <utility> @@ -1023,6 +1024,106 @@ class OMPSizesClause final } }; +/// This represents the 'counts' clause in the '#pragma omp split' directive. +/// +/// \code +/// #pragma omp split counts(3, omp_fill, 2) +/// for (int i = 0; i < n; ++i) { ... } +/// \endcode +class OMPCountsClause final + : public OMPClause, + private llvm::TrailingObjects<OMPCountsClause, Expr *> { + friend class OMPClauseReader; + friend class llvm::TrailingObjects<OMPCountsClause, Expr *>; + + /// Location of '('. + SourceLocation LParenLoc; + + /// Number of count expressions in the clause. + unsigned NumCounts = 0; + + /// 0-based index of the omp_fill list item. + std::optional<unsigned> OmpFillIndex; + + /// Source location of the omp_fill keyword. + SourceLocation OmpFillLoc; + + /// Build an empty clause. + explicit OMPCountsClause(int NumCounts) + : OMPClause(llvm::omp::OMPC_counts, SourceLocation(), SourceLocation()), + NumCounts(NumCounts) {} + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + void setOmpFillIndex(std::optional<unsigned> Idx) { OmpFillIndex = Idx; } + void setOmpFillLoc(SourceLocation Loc) { OmpFillLoc = Loc; } + + /// Sets the count expressions. + void setCountsRefs(ArrayRef<Expr *> VL) { + assert(VL.size() == NumCounts); + llvm::copy(VL, getCountsRefs().begin()); + } + +public: + /// Build a 'counts' AST node. + /// + /// \param C Context of the AST. 
+ /// \param StartLoc Location of the 'counts' identifier. + /// \param LParenLoc Location of '('. + /// \param EndLoc Location of ')'. + /// \param Counts Content of the clause. + static OMPCountsClause *Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc, ArrayRef<Expr *> Counts, + std::optional<unsigned> FillIdx, + SourceLocation FillLoc); + + /// Build an empty 'counts' AST node for deserialization. + /// + /// \param C Context of the AST. + /// \param NumCounts Number of items in the clause. + static OMPCountsClause *CreateEmpty(const ASTContext &C, unsigned NumCounts); + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns the number of list items. + unsigned getNumCounts() const { return NumCounts; } + + std::optional<unsigned> getOmpFillIndex() const { return OmpFillIndex; } + SourceLocation getOmpFillLoc() const { return OmpFillLoc; } + bool hasOmpFill() const { return OmpFillIndex.has_value(); } + + /// Returns the count expressions. 
+ MutableArrayRef<Expr *> getCountsRefs() { + return getTrailingObjects(NumCounts); + } + ArrayRef<Expr *> getCountsRefs() const { + return getTrailingObjects(NumCounts); + } + + child_range children() { + MutableArrayRef<Expr *> Counts = getCountsRefs(); + return child_range(reinterpret_cast<Stmt **>(Counts.begin()), + reinterpret_cast<Stmt **>(Counts.end())); + } + const_child_range children() const { + ArrayRef<Expr *> Counts = getCountsRefs(); + return const_child_range(reinterpret_cast<Stmt *const *>(Counts.begin()), + reinterpret_cast<Stmt *const *>(Counts.end())); + } + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_counts; + } +}; + /// This class represents the 'permutation' clause in the /// '#pragma omp interchange' directive. 
/// diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index ce6ad723191e0..1a14dd2c666b5 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3202,6 +3202,9 @@ DEF_TRAVERSE_STMT(OMPFuseDirective, DEF_TRAVERSE_STMT(OMPInterchangeDirective, { TRY_TO(TraverseOMPExecutableDirective(S)); }) +DEF_TRAVERSE_STMT(OMPSplitDirective, + { TRY_TO(TraverseOMPExecutableDirective(S)); }) + DEF_TRAVERSE_STMT(OMPForDirective, { TRY_TO(TraverseOMPExecutableDirective(S)); }) @@ -3503,6 +3506,13 @@ bool RecursiveASTVisitor<Derived>::VisitOMPSizesClause(OMPSizesClause *C) { return true; } +template <typename Derived> +bool RecursiveASTVisitor<Derived>::VisitOMPCountsClause(OMPCountsClause *C) { + for (Expr *E : C->getCountsRefs()) + TRY_TO(TraverseStmt(E)); + return true; +} + template <typename Derived> bool RecursiveASTVisitor<Derived>::VisitOMPPermutationClause( OMPPermutationClause *C) { diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index bc6aeaa8d143c..dbc76e7df8ecd 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -6065,6 +6065,84 @@ class OMPFuseDirective final } }; +/// Represents the '#pragma omp split' loop transformation directive. +/// +/// \code{.c} +/// #pragma omp split counts(3, omp_fill, 2) +/// for (int i = 0; i < n; ++i) +/// ... +/// \endcode +/// +/// This directive transforms a single loop into multiple loops based on +/// index ranges. The transformation splits the iteration space of the loop +/// into multiple contiguous ranges. The \c counts clause is required and +/// exactly one list item must be \c omp_fill. +class OMPSplitDirective final + : public OMPCanonicalLoopNestTransformationDirective { + friend class ASTStmtReader; + friend class OMPExecutableDirective; + + /// Offsets of child members. 
+ enum { + PreInitsOffset = 0, + TransformedStmtOffset, + }; + + explicit OMPSplitDirective(SourceLocation StartLoc, SourceLocation EndLoc, + unsigned NumLoops) + : OMPCanonicalLoopNestTransformationDirective( + OMPSplitDirectiveClass, llvm::omp::OMPD_split, StartLoc, EndLoc, + NumLoops) {} + + void setPreInits(Stmt *PreInits) { + Data->getChildren()[PreInitsOffset] = PreInits; + } + + void setTransformedStmt(Stmt *S) { + Data->getChildren()[TransformedStmtOffset] = S; + } + +public: + /// Create a new AST node representation for '#pragma omp split'. + /// + /// \param C Context of the AST. + /// \param StartLoc Location of the introducer (e.g. the 'omp' token). + /// \param EndLoc Location of the directive's end (e.g. the tok::eod). + /// \param Clauses The directive's clauses (e.g. the required \c counts + /// clause). + /// \param NumLoops Number of affected loops (should be 1 for split). + /// \param AssociatedStmt The outermost associated loop. + /// \param TransformedStmt The loop nest after splitting, or nullptr in + /// dependent contexts. + /// \param PreInits Helper preinits statements for the loop nest. + static OMPSplitDirective *Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation EndLoc, + ArrayRef<OMPClause *> Clauses, + unsigned NumLoops, Stmt *AssociatedStmt, + Stmt *TransformedStmt, Stmt *PreInits); + + /// Build an empty '#pragma omp split' AST node for deserialization. + /// + /// \param C Context of the AST. + /// \param NumClauses Number of clauses to allocate. + /// \param NumLoops Number of associated loops to allocate. + static OMPSplitDirective *CreateEmpty(const ASTContext &C, + unsigned NumClauses, unsigned NumLoops); + + /// Gets/sets the associated loops after the transformation, i.e. after + /// de-sugaring. + Stmt *getTransformedStmt() const { + return Data->getChildren()[TransformedStmtOffset]; + } + + /// Return preinits statement. 
+ Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == OMPSplitDirectiveClass; + } +}; + /// This represents '#pragma omp scan' directive. /// /// \code diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h index e8e7643e0dddd..87b6dbefa7a62 100644 --- a/clang/include/clang/ASTMatchers/ASTMatchers.h +++ b/clang/include/clang/ASTMatchers/ASTMatchers.h @@ -8781,6 +8781,26 @@ extern const internal::VariadicDynCastAllOfMatcher<Stmt, OMPTargetUpdateDirective> ompTargetUpdateDirective; +/// Matches any ``#pragma omp split`` executable directive. +/// +/// Given +/// +/// \code +/// #pragma omp split counts(2, omp_fill) +/// for (int i = 0; i < n; ++i) {} +/// \endcode +/// +/// ``ompSplitDirective()`` matches the split directive. +extern const internal::VariadicDynCastAllOfMatcher<Stmt, OMPSplitDirective> + ompSplitDirective; + +/// Matches OpenMP ``counts`` clause used by ``#pragma omp split``. +/// +/// Given ``#pragma omp split counts(1, 2, omp_fill)``, ``ompCountsClause()`` +/// matches the ``counts`` clause node. +extern const internal::VariadicDynCastAllOfMatcher<OMPClause, OMPCountsClause> + ompCountsClause; + /// Matches OpenMP ``default`` clause. 
/// /// Given diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index d5904bd1d6f26..71d504c659cc2 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11176,6 +11176,8 @@ def err_omp_bind_required_on_loop : Error< "construct">; def err_omp_loop_reduction_clause : Error< "'reduction' clause not allowed with '#pragma omp loop bind(teams)'">; +def err_omp_split_counts_not_one_omp_fill : Error< + "exactly one 'omp_fill' must appear in the 'counts' clause">; def warn_break_binds_to_switch : Warning< "'break' is bound to loop, GCC binds it to switch">, InGroup<GccCompat>; diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td index 61d76bafdfcde..e166894ea024b 100644 --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -244,6 +244,7 @@ def OMPTileDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; def OMPStripeDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; def OMPUnrollDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; def OMPReverseDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def OMPSplitDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; def OMPInterchangeDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; def OMPCanonicalLoopSequenceTransformationDirective diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h index 08a3d88ee6a36..bd313d37cc4b5 100644 --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -6812,6 +6812,9 @@ class Parser : public CodeCompletionHandler { /// Parses the 'sizes' clause of a '#pragma omp tile' directive. OMPClause *ParseOpenMPSizesClause(); + /// Parses the 'counts' clause of a '#pragma omp split' directive. 
+ OMPClause *ParseOpenMPCountsClause(); + /// Parses the 'permutation' clause of a '#pragma omp interchange' directive. OMPClause *ParseOpenMPPermutationClause(); diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h index 7853f29f98c25..3621ce96b8724 100644 --- a/clang/include/clang/Sema/SemaOpenMP.h +++ b/clang/include/clang/Sema/SemaOpenMP.h @@ -42,6 +42,7 @@ class FunctionScopeInfo; class DeclContext; class DeclGroupRef; +class EnumConstantDecl; class ParsedAttr; class Scope; @@ -457,6 +458,11 @@ class SemaOpenMP : public SemaBase { /// Called on well-formed '#pragma omp reverse'. StmtResult ActOnOpenMPReverseDirective(Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); + /// Called on well-formed '#pragma omp split' after parsing of its + /// associated statement. + StmtResult ActOnOpenMPSplitDirective(ArrayRef<OMPClause *> Clauses, + Stmt *AStmt, SourceLocation StartLoc, + SourceLocation EndLoc); /// Called on well-formed '#pragma omp interchange' after parsing of its /// clauses and the associated statement. StmtResult ActOnOpenMPInterchangeDirective(ArrayRef<OMPClause *> Clauses, @@ -911,6 +917,12 @@ class SemaOpenMP : public SemaBase { SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); + /// Called on well-formed 'counts' clause after parsing its arguments. + OMPClause * + ActOnOpenMPCountsClause(ArrayRef<Expr *> CountExprs, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc, + std::optional<unsigned> FillIdx, + SourceLocation FillLoc, unsigned FillCount); /// Called on well-form 'permutation' clause after parsing its arguments. 
OMPClause *ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs, SourceLocation StartLoc, diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h index 783cd82895a90..9b798ed484454 100644 --- a/clang/include/clang/Serialization/ASTBitCodes.h +++ b/clang/include/clang/Serialization/ASTBitCodes.h @@ -1965,6 +1965,7 @@ enum StmtCode { STMP_OMP_STRIPE_DIRECTIVE, STMT_OMP_UNROLL_DIRECTIVE, STMT_OMP_REVERSE_DIRECTIVE, + STMT_OMP_SPLIT_DIRECTIVE, STMT_OMP_INTERCHANGE_DIRECTIVE, STMT_OMP_FUSE_DIRECTIVE, STMT_OMP_FOR_DIRECTIVE, diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp index d4826c3c6edca..3a35e17aff40b 100644 --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -15,10 +15,12 @@ #include "clang/AST/Attr.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclOpenMP.h" +#include "clang/AST/Expr.h" #include "clang/AST/ExprOpenMP.h" #include "clang/Basic/LLVM.h" #include "clang/Basic/OpenMPKinds.h" #include "clang/Basic/TargetInfo.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/ErrorHandling.h" #include <algorithm> @@ -986,6 +988,26 @@ OMPSizesClause *OMPSizesClause::CreateEmpty(const ASTContext &C, return new (Mem) OMPSizesClause(NumSizes); } +OMPCountsClause *OMPCountsClause::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc, ArrayRef<Expr *> Counts, + std::optional<unsigned> FillIdx, SourceLocation FillLoc) { + OMPCountsClause *Clause = CreateEmpty(C, Counts.size()); + Clause->setLocStart(StartLoc); + Clause->setLParenLoc(LParenLoc); + Clause->setLocEnd(EndLoc); + Clause->setCountsRefs(Counts); + Clause->setOmpFillIndex(FillIdx); + Clause->setOmpFillLoc(FillLoc); + return Clause; +} + +OMPCountsClause *OMPCountsClause::CreateEmpty(const ASTContext &C, + unsigned NumCounts) { + void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(NumCounts)); + return new (Mem) 
OMPCountsClause(NumCounts); +} + OMPPermutationClause *OMPPermutationClause::Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc, @@ -1984,6 +2006,19 @@ void OMPClausePrinter::VisitOMPSizesClause(OMPSizesClause *Node) { OS << ")"; } +void OMPClausePrinter::VisitOMPCountsClause(OMPCountsClause *Node) { + OS << "counts("; + std::optional<unsigned> FillIdx = Node->getOmpFillIndex(); + ArrayRef<Expr *> Refs = Node->getCountsRefs(); + llvm::interleaveComma(llvm::seq<unsigned>(Refs.size()), OS, [&](unsigned I) { + if (FillIdx && I == *FillIdx) + OS << "omp_fill"; + else + Refs[I]->printPretty(OS, nullptr, Policy, 0); + }); + OS << ")"; +} + void OMPClausePrinter::VisitOMPPermutationClause(OMPPermutationClause *Node) { OS << "permutation("; llvm::interleaveComma(Node->getArgsRefs(), OS, [&](const Expr *E) { diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp index a5b0cd3786a28..9d6b315effb41 100644 --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -552,6 +552,27 @@ OMPInterchangeDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses, SourceLocation(), SourceLocation(), NumLoops); } +OMPSplitDirective * +OMPSplitDirective::Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, + unsigned NumLoops, Stmt *AssociatedStmt, + Stmt *TransformedStmt, Stmt *PreInits) { + OMPSplitDirective *Dir = createDirective<OMPSplitDirective>( + C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc, + NumLoops); + Dir->setTransformedStmt(TransformedStmt); + Dir->setPreInits(PreInits); + return Dir; +} + +OMPSplitDirective *OMPSplitDirective::CreateEmpty(const ASTContext &C, + unsigned NumClauses, + unsigned NumLoops) { + return createEmptyDirective<OMPSplitDirective>( + C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1, + SourceLocation(), SourceLocation(), NumLoops); +} + OMPFuseDirective *OMPFuseDirective::Create( const 
ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, unsigned NumGeneratedTopLevelLoops, diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp index 4d364fdcd5502..e0b930ba0a21a 100644 --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -800,6 +800,11 @@ void StmtPrinter::VisitOMPInterchangeDirective(OMPInterchangeDirective *Node) { PrintOMPExecutableDirective(Node); } +void StmtPrinter::VisitOMPSplitDirective(OMPSplitDirective *Node) { + Indent() << "#pragma omp split"; + PrintOMPExecutableDirective(Node); +} + void StmtPrinter::VisitOMPFuseDirective(OMPFuseDirective *Node) { Indent() << "#pragma omp fuse"; PrintOMPExecutableDirective(Node); diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index e8c1f8a8ecb5f..c75652e5c1dd3 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -498,6 +498,12 @@ void OMPClauseProfiler::VisitOMPSizesClause(const OMPSizesClause *C) { Profiler->VisitExpr(E); } +void OMPClauseProfiler::VisitOMPCountsClause(const OMPCountsClause *C) { + for (auto *E : C->getCountsRefs()) + if ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/190397 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
