llvmbot wrote:


@llvm/pr-subscribers-openmp

Author: Julian Brown (jtb20)

<details>
<summary>Changes</summary>

This patch series outlines a new approach to implementing OpenMP 6.0 
"taskgraph" support in LLVM. Key differences from the previously posted 
"record and replay" implementation are as follows.

  - The task/taskdata structures and the dependencies between them are 
duplicated whilst recording a taskgraph, so that the existing runtime 
dependency handling is unaffected by the taskgraph implementation; e.g. during 
normal runtime execution, it is valid for output dependencies to be dropped as 
soon as the producing task completes.  This separation is intended to eliminate 
a class of race conditions in which, depending on when an earlier task happens 
to complete, its dependency edge to a subsequent task might or might not be 
recorded.

  - A new set of entry points is used for recording tasks within a taskgraph, 
for two reasons.  The first is to pass extra information from the compiler 
that is relevant to the taskgraph-recording case; e.g. the 
__kmpc_taskgraph_task entry point has extra arguments relating to shared data.  
The second is to reduce the potential overhead of the taskgraph implementation 
on the rest of the runtime. (1)

  - The dependencies between tasks in a taskgraph are processed by static 
analysis: at a high level, the process is akin to turning data dependencies 
between tasks into control-flow dependencies.  This is done by building a set 
of successors and predecessors for each recorded task, then decomposing the 
resulting DAG into parallel and sequential regions.  If the graph is 
irreducible (presumed to be relatively unlikely in real-world code), a further 
set of analyses and transformations is performed, and the parallel-sequential 
decomposition is run again. (2)

  The output of this process is a set of nested kmp_taskgraph_region 
structures: parallel or sequential regions (each with some number of 
children), or nodes representing a single task.  The parallel and sequential 
phases alternate until a single top-level region remains.

  - Replaying a taskgraph processed in this way on the CPU involves another set 
of linked structures, of type kmp_taskgraph_exec_descr.  These form a kind of 
trace of a traversal over the kmp_taskgraph_region structures, so that a 
pointer to a kmp_taskgraph_exec_descr acts somewhat like a "program counter". (3)

  - Recorded taskgraphs are now located directly, using a handle passed in from 
the user's compiled program, rather than by searching a linked list or 
hashtable of taskgraph records keyed by an index.
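The parallel-sequential decomposition described in the list above can be illustrated with a minimal C++17 sketch. This is not the patch's algorithm: it assumes a plain node-based series-parallel reduction, and `Region`, `leaf`, `combine`, `describe`, and `decompose` are hypothetical names loosely modelled on kmp_taskgraph_region. A sequential phase merges a task with its sole successor when that successor has no other predecessor; a parallel phase merges tasks that have identical predecessor and successor sets; the two phases alternate until one region remains or no reduction applies.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Hypothetical region node: a single task, or a sequential/parallel
// composition of child regions.
struct Region {
  enum Kind { Task, Seq, Par };
  Kind kind = Task;
  int task_id = -1; // only meaningful when kind == Task
  std::vector<std::shared_ptr<Region>> children;
};
using RegionPtr = std::shared_ptr<Region>;

RegionPtr leaf(int id) {
  auto r = std::make_shared<Region>();
  r->task_id = id;
  return r;
}

// Compose a and b into a region of the given kind, flattening children
// that are already of that kind (seq(seq(a,b),c) becomes seq(a,b,c)).
RegionPtr combine(Region::Kind kind, const RegionPtr &a, const RegionPtr &b) {
  auto r = std::make_shared<Region>();
  r->kind = kind;
  for (const RegionPtr &c : {a, b}) {
    if (c->kind == kind)
      r->children.insert(r->children.end(), c->children.begin(),
                         c->children.end());
    else
      r->children.push_back(c);
  }
  return r;
}

// Render a region as text, e.g. "seq(0,par(1,2),3)".
std::string describe(const RegionPtr &r) {
  if (r->kind == Region::Task) return std::to_string(r->task_id);
  std::string s = r->kind == Region::Seq ? "seq(" : "par(";
  for (std::size_t i = 0; i < r->children.size(); ++i) {
    if (i) s += ",";
    s += describe(r->children[i]);
  }
  return s + ")";
}

// Decompose a DAG of n tasks (edges[u] lists the successors of task u).
// Returns nullptr if the graph cannot be reduced this way.
RegionPtr decompose(int n, const std::vector<std::vector<int>> &edges) {
  assert(edges.size() == static_cast<std::size_t>(n));
  std::vector<std::set<int>> succ(n), pred(n);
  for (int u = 0; u < n; ++u)
    for (int v : edges[u]) {
      succ[u].insert(v);
      pred[v].insert(u);
    }
  std::vector<RegionPtr> region(n);
  std::vector<bool> alive(n, true);
  for (int i = 0; i < n; ++i) region[i] = leaf(i);
  int live = n;
  for (bool changed = true; live > 1 && changed;) {
    changed = false;
    // Sequential phase: v is u's only successor and u is v's only predecessor.
    for (int u = 0; u < n; ++u) {
      if (!alive[u] || succ[u].size() != 1) continue;
      int v = *succ[u].begin();
      if (pred[v].size() != 1) continue;
      region[u] = combine(Region::Seq, region[u], region[v]);
      succ[u] = succ[v]; // u takes over v's outgoing edges
      for (int w : succ[v]) {
        pred[w].erase(v);
        pred[w].insert(u);
      }
      alive[v] = false;
      --live;
      changed = true;
    }
    // Parallel phase: u and v have identical predecessor and successor sets.
    for (int u = 0; u < n; ++u) {
      if (!alive[u]) continue;
      for (int v = u + 1; v < n; ++v) {
        if (!alive[v] || pred[u] != pred[v] || succ[u] != succ[v]) continue;
        region[u] = combine(Region::Par, region[u], region[v]);
        for (int w : succ[v]) pred[w].erase(v);
        for (int w : pred[v]) succ[w].erase(v);
        alive[v] = false;
        --live;
        changed = true;
      }
    }
  }
  if (live != 1) return nullptr;
  for (int i = 0; i < n; ++i)
    if (alive[i]) return region[i];
  return nullptr;
}
```

For example, the diamond graph 0→{1,2}→3 reduces to `seq(0,par(1,2),3)`, while the "N"-shaped graph with edges 0→2, 0→3 and 1→3 cannot be reduced; that is presumably the kind of graph the description calls irreducible, which the patch handles with further analyses (not modelled here). The sketch also does not remove transitive edges first, so a redundant edge such as 0→2 alongside 0→1→2 makes it report failure.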

(1) A third intention is to capture OpenMP semantics at a slightly higher 
level: in particular, when we come to add offload target tasks to this 
implementation, those will also use new API entry points, with the aim of 
allowing dependencies to be handled entirely on the GPU rather than by 
wrapping each target task in a host task.

(2) This process takes some time, but I have made some effort to keep it 
efficient.  For example, allocations and deallocations are kept to a minimum 
by recycling kmp_taskgraph_region_dep_t structures (which are always the same 
size) and by allocating all kmp_depnode_t structures together in a single 
block (in __kmp_build_taskgraph).
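The two allocation strategies in note (2) are general techniques, sketched below in C++14 as a rough illustration. `DepRecord`, `DepRecordPool`, `DepNode`, and `make_nodes` are hypothetical names, not the runtime's code: a free list recycles same-sized records instead of returning them to the heap, and all per-task nodes come from one contiguous allocation once the task count is known.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical fixed-size record (standing in for a dependency record).
// Every instance has the same size, so any released record can satisfy
// any later request.
struct DepRecord {
  int from = -1;
  int to = -1;
};

// Recycles released records instead of freeing them.
class DepRecordPool {
public:
  DepRecord *acquire(int from, int to) {
    DepRecord *r;
    if (!free_.empty()) {
      r = free_.back(); // reuse a previously released record
      free_.pop_back();
    } else {
      owned_.push_back(std::make_unique<DepRecord>());
      r = owned_.back().get();
    }
    r->from = from;
    r->to = to;
    return r;
  }
  void release(DepRecord *r) {
    assert(r != nullptr);
    free_.push_back(r); // kept alive by owned_, ready for reuse
  }
  std::size_t heap_allocations() const { return owned_.size(); }

private:
  std::vector<std::unique_ptr<DepRecord>> owned_; // every record ever created
  std::vector<DepRecord *> free_;                 // records available for reuse
};

// Hypothetical per-task node, allocated in one block for all tasks.
struct DepNode {
  int task_id = -1;
  int npredecessors = 0;
};

std::vector<DepNode> make_nodes(int n) {
  std::vector<DepNode> nodes(static_cast<std::size_t>(n)); // single allocation
  for (int i = 0; i < n; ++i) nodes[i].task_id = i;
  return nodes;
}
```

Releasing a record and then acquiring another hands back the same storage, so repeated analysis passes that create and discard edges stop touching the general-purpose allocator once the pool has grown to its peak size.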

(3) The intention is that GPU/offload execution will take the nested 
kmp_taskgraph_region structure (potentially containing a mix of target tasks 
and host tasks) and map it, in whatever way is appropriate, onto a GPU 
graph-execution API or a (suitably extended) liboffload GPU backend.
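To make the "program counter" analogy from the CPU replay bullet concrete, the following C++17 sketch flattens a nested region tree into linked execution descriptors once, then replays them by walking a pointer. The names (`Region`, `ExecDescr`, `ExecProgram`, `run`) and the layout are assumptions, not the patch's kmp_taskgraph_exec_descr. Each region is translated in continuation-passing style: a task descriptor points at whatever runs next, and a parallel region becomes a `Fork` whose branches all end at a shared `Join`.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical nested region: a task, or a seq/par composition.
struct Region {
  enum Kind { Task, Seq, Par };
  Kind kind;
  int task_id; // only meaningful for Task
  std::vector<Region> children;
};

// Hypothetical execution descriptor; a pointer to one acts like a PC.
struct ExecDescr {
  enum Kind { RunTask, Fork, Join };
  Kind kind = RunTask;
  int task_id = -1;                        // RunTask only
  const ExecDescr *next = nullptr;         // RunTask, Join: what runs next
  std::vector<const ExecDescr *> branches; // Fork: entry point of each branch
  const ExecDescr *join = nullptr;         // Fork: matching Join
};

// Owns all descriptors built for one recorded graph.
struct ExecProgram {
  std::vector<std::unique_ptr<ExecDescr>> pool;

  ExecDescr *alloc(ExecDescr::Kind k) {
    pool.push_back(std::make_unique<ExecDescr>());
    pool.back()->kind = k;
    return pool.back().get();
  }

  // Translate r into descriptors that continue at `cont` when finished.
  const ExecDescr *build(const Region &r, const ExecDescr *cont) {
    switch (r.kind) {
    case Region::Task: {
      ExecDescr *d = alloc(ExecDescr::RunTask);
      d->task_id = r.task_id;
      d->next = cont;
      return d;
    }
    case Region::Seq:
      // Build back to front so each child continues into the next one.
      for (auto it = r.children.rbegin(); it != r.children.rend(); ++it)
        cont = build(*it, cont);
      return cont;
    case Region::Par: {
      ExecDescr *join = alloc(ExecDescr::Join);
      join->next = cont;
      ExecDescr *fork = alloc(ExecDescr::Fork);
      fork->join = join;
      for (const Region &c : r.children)
        fork->branches.push_back(build(c, join));
      return fork;
    }
    }
    return cont;
  }
};

// Walk from pc until reaching stop, appending executed task ids to trace.
void run(const ExecDescr *pc, const ExecDescr *stop, std::vector<int> &trace) {
  while (pc != stop) {
    switch (pc->kind) {
    case ExecDescr::RunTask:
      trace.push_back(pc->task_id); // stand-in for executing the task
      pc = pc->next;
      break;
    case ExecDescr::Fork:
      for (const ExecDescr *branch : pc->branches)
        run(branch, pc->join, trace); // each branch stops at the Join
      pc = pc->join->next;            // continue after the Join
      break;
    case ExecDescr::Join:
      assert(false && "a Join is only reached as a branch's stop point");
      return;
    }
  }
}
```

Because the descriptors are built once and only read during replay, the same entry pointer can be walked on every replay. Here the branches of a `Fork` run one after another; a real runtime would presumably hand them to different threads and wait at the `Join`.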

There is also an implementation of the "replayable" clause, but not (yet) the 
"saved" modifier.

Some care has been taken around the handling of implicit taskgroups: in 
particular, a taskgraph can contain two back-to-back taskloops, each of which 
has a reduction:

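    #pragma omp taskgraph
    {
        // taskloop must be followed by a loop; n is an illustrative trip count
        #pragma omp taskloop reduction(+: var1)
        for (int i = 0; i < n; i++)
            var1 += ...;

        #pragma omp taskloop reduction(+: var2)
        for (int i = 0; i < n; i++)
            var2 += var1;
    }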

This seems legal, but it means that when a taskloop within a replayed 
taskgraph has a reduction, we still need to create and destroy its taskgroup 
structures, so that the first reduction result can be used safely in the 
second loop.  (Alternatively, we could perhaps retain and reuse the taskgroup 
structures, but that is not done yet.)

The patch series goes some way towards full thread safety: the global shared 
state is all gone now, but there are still problems to be addressed.  In 
principle we could perhaps have a set of kmp_taskgraph_exec_descr structures 
per thread, pointing back to a shared kmp_taskgraph_record/kmp_taskgraph_region 
structure, but in practice we'd probably need to duplicate the underlying 
task/taskdata structures too.

(This patch series builds on top of previous work, which is also included in 
the commit series. The bulk of the new work is in the top four patches.)

---

Patch is 413.46 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/188765.diff


85 Files Affected:

- (modified) clang/bindings/python/clang/cindex.py (+3) 
- (modified) clang/include/clang-c/Index.h (+4) 
- (modified) clang/include/clang/AST/OpenMPClause.h (+158) 
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+24) 
- (modified) clang/include/clang/AST/StmtOpenMP.h (+49) 
- (modified) clang/include/clang/Basic/StmtNodes.td (+1) 
- (modified) clang/include/clang/Sema/SemaOpenMP.h (+19) 
- (modified) clang/include/clang/Serialization/ASTBitCodes.h (+1) 
- (modified) clang/lib/AST/OpenMPClause.cpp (+51) 
- (modified) clang/lib/AST/StmtOpenMP.cpp (+15) 
- (modified) clang/lib/AST/StmtPrinter.cpp (+5) 
- (modified) clang/lib/AST/StmtProfile.cpp (+20) 
- (modified) clang/lib/Basic/OpenMPKinds.cpp (+7) 
- (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+553-198) 
- (modified) clang/lib/CodeGen/CGOpenMPRuntime.h (+25-4) 
- (modified) clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp (+2) 
- (modified) clang/lib/CodeGen/CGStmt.cpp (+3) 
- (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+68-10) 
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+17) 
- (modified) clang/lib/Parse/ParseOpenMP.cpp (+16-1) 
- (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1) 
- (modified) clang/lib/Sema/SemaOpenMP.cpp (+139) 
- (modified) clang/lib/Sema/TreeTransform.h (+80) 
- (modified) clang/lib/Serialization/ASTReader.cpp (+26) 
- (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+10) 
- (modified) clang/lib/Serialization/ASTWriter.cpp (+17) 
- (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+6) 
- (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+1) 
- (added) clang/test/OpenMP/taskgraph_ast_print.cpp (+31) 
- (added) clang/test/OpenMP/taskgraph_codegen.cpp (+52) 
- (modified) clang/tools/libclang/CIndex.cpp (+12) 
- (modified) clang/tools/libclang/CXCursor.cpp (+3) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+3) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPKinds.def (+11) 
- (modified) openmp/runtime/CMakeLists.txt (+3-3) 
- (modified) openmp/runtime/src/kmp.h (+201-81) 
- (modified) openmp/runtime/src/kmp_config.h.cmake (+2-2) 
- (modified) openmp/runtime/src/kmp_debug.h (+14) 
- (modified) openmp/runtime/src/kmp_global.cpp (+1-12) 
- (modified) openmp/runtime/src/kmp_settings.cpp (+6-28) 
- (modified) openmp/runtime/src/kmp_taskdeps.cpp (+2986-266) 
- (modified) openmp/runtime/src/kmp_taskdeps.h (+17-31) 
- (modified) openmp/runtime/src/kmp_tasking.cpp (+906-445) 
- (modified) openmp/runtime/test/CMakeLists.txt (+1-1) 
- (modified) openmp/runtime/test/lit.cfg (+2-2) 
- (modified) openmp/runtime/test/lit.site.cfg.in (+1-1) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_1.cpp (+50) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_10.cpp (+47) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_11.cpp (+57) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_12.cpp (+52) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_13.cpp (+42) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_14.cpp (+45) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_15.cpp (+72) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_16.cpp (+52) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_17.cpp (+65) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_18.cpp (+43) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_19.cpp (+48) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_2.cpp (+55) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_20.cpp (+48) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_21.cpp (+49) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_22.cpp (+67) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_23.cpp (+100) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_24.cpp (+77) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_25.cpp (+86) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_26.cpp (+58) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_27.cpp (+60) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_3.cpp (+77) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_4.cpp (+73) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_5.cpp (+60) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_6.cpp (+56) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_7.cpp (+56) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_8.cpp (+36) 
- (added) openmp/runtime/test/taskgraph/taskgraph_deps_9.cpp (+44) 
- (removed) openmp/runtime/test/tasking/omp_record_replay.cpp (-48) 
- (removed) openmp/runtime/test/tasking/omp_record_replay_deps.cpp (-63) 
- (removed) openmp/runtime/test/tasking/omp_record_replay_deps_multi_succ.cpp (-56) 
- (removed) openmp/runtime/test/tasking/omp_record_replay_multiTDGs.cpp (-76) 
- (removed) openmp/runtime/test/tasking/omp_record_replay_print_dot.cpp (-80) 
- (added) openmp/runtime/test/tasking/omp_record_replay_random_id.cpp (+47) 
- (added) openmp/runtime/test/tasking/omp_record_replay_reset.cpp (+47) 
- (removed) openmp/runtime/test/tasking/omp_record_replay_taskloop.cpp (-50) 
- (added) openmp/runtime/test/tasking/omp_taskgraph.cpp (+35) 
- (added) openmp/runtime/test/tasking/omp_taskgraph_deps.cpp (+52) 
- (added) openmp/runtime/test/tasking/omp_taskgraph_multiTDGs.cpp (+66) 
- (added) openmp/runtime/test/tasking/omp_taskgraph_taskloop.cpp (+39) 


``````````diff
diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py
index 1896a0a9c1c34..093bfc669b82f 100644
--- a/clang/bindings/python/clang/cindex.py
+++ b/clang/bindings/python/clang/cindex.py
@@ -1448,6 +1448,9 @@ def is_unexposed(self):
     # OpenMP fuse directive.
     OMP_FUSE_DIRECTIVE = 311
 
+    # OpenMP taskgraph directive.
+    OMP_TASKGRAPH_DIRECTIVE = 312
+
     # OpenACC Compute Construct.
     OPEN_ACC_COMPUTE_DIRECTIVE = 320
 
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index 203634c80d82a..31a43260edcec 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2166,6 +2166,10 @@ enum CXCursorKind {
    */
   CXCursor_OMPFuseDirective = 311,
 
+  /** OpenMP taskgraph directive.
+   */
+  CXCursor_OMPTaskgraphDirective = 312,
+
   /** OpenACC Compute Construct.
    */
   CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index af5d3f4698eda..0860aca973516 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -1939,6 +1939,75 @@ class OMPSelfMapsClause final : public OMPClause {
   }
 };
 
+/// This represents a 'replayable' clause in the '#pragma omp target',
+/// '#pragma omp target enter data', '#pragma omp target exit data',
+/// '#pragma omp target update', '#pragma omp task', '#pragma omp taskloop' or
+/// '#pragma omp taskwait' directive.
+///
+/// \code
+/// #pragma omp task replayable(1)
+/// \endcode
+/// In this example directive '#pragma omp task' has the 'replayable' clause.
+class OMPReplayableClause final : public OMPClause {
+public:
+  friend class OMPClauseReader;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+  /// Condition of the 'replayable' clause.
+  Stmt *Condition = nullptr;
+
+  /// Set condition.
+  void setCondition(Expr *Cond) { Condition = Cond; }
+
+  /// Build 'replayable' clause.
+  ///
+  /// \param Cond Condition of the clause.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPReplayableClause(Expr *Cond, SourceLocation StartLoc,
+                      SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_replayable, StartLoc, EndLoc),
+        LParenLoc(LParenLoc), Condition(Cond) {}
+
+  /// Build an empty clause.
+  OMPReplayableClause()
+      : OMPClause(llvm::omp::OMPC_replayable, SourceLocation(),
+                  SourceLocation()) {}
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns condition.
+  Expr *getCondition() const { return cast_or_null<Expr>(Condition); }
+
+  child_range children() {
+    if (Condition)
+      return child_range(&Condition, &Condition + 1);
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    if (Condition)
+      return const_child_range(&Condition, &Condition + 1);
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children();
+  const_child_range used_children() const {
+    return const_cast<OMPReplayableClause *>(this)->used_children();
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_replayable;
+  }
+};
+
 /// This represents 'at' clause in the '#pragma omp error' directive
 ///
 /// \code
@@ -8440,6 +8509,95 @@ class OMPIsDevicePtrClause final
   }
 };
 
+/// This represents clause 'graph_id' in the '#pragma omp taskgraph'
+/// directive.
+///
+/// \code
+/// #pragma omp taskgraph graph_id(a)
+class OMPGraphIdClause final
+    : public OMPOneStmtClause<llvm::omp::OMPC_graph_id, OMPClause>,
+      public OMPClauseWithPreInit {
+  friend class OMPClauseReader;
+
+  /// Set condition.
+  void setCondition(Expr *Cond) { setStmt(Cond); }
+
+public:
+  /// Build 'graph_id' clause with condition \a Cond.
+  ///
+  /// \param Cond Condition of the clause.
+  /// \param HelperCond Helper condition for the construct.
+  /// \param CaptureRegion Innermost OpenMP region where expressions in this
+  /// clause must be captured.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPGraphIdClause(Expr *Cond, Stmt *HelperCond,
+                   OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc,
+                   SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OMPOneStmtClause(Cond, StartLoc, LParenLoc, EndLoc),
+        OMPClauseWithPreInit(this) {
+    setPreInitStmt(HelperCond, CaptureRegion);
+  }
+
+  /// Build an empty clause.
+  OMPGraphIdClause() : OMPOneStmtClause(), OMPClauseWithPreInit(this) {}
+
+  /// Returns condition.
+  Expr *getCondition() const { return getStmtAs<Expr>(); }
+
+  child_range used_children();
+  const_child_range used_children() const {
+    auto Children = const_cast<OMPGraphIdClause *>(this)->used_children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+};
+
+/// This represents clause 'graph_reset' in the '#pragma omp taskgraph'
+/// directive.
+///
+/// \code
+/// #pragma omp taskgraph graph_reset(true)
+class OMPGraphResetClause final
+    : public OMPOneStmtClause<llvm::omp::OMPC_graph_reset, OMPClause>,
+      public OMPClauseWithPreInit {
+  friend class OMPClauseReader;
+
+  /// Set condition.
+  void setCondition(Expr *Cond) { setStmt(Cond); }
+
+public:
+  /// Build 'graph_reset' clause with condition \a Cond.
+  ///
+  /// \param Cond Condition of the clause.
+  /// \param HelperCond Helper condition for the construct.
+  /// \param CaptureRegion Innermost OpenMP region where expressions in this
+  /// clause must be captured.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPGraphResetClause(Expr *Cond, Stmt *HelperCond,
+                      OpenMPDirectiveKind CaptureRegion,
+                      SourceLocation StartLoc, SourceLocation LParenLoc,
+                      SourceLocation EndLoc)
+      : OMPOneStmtClause(Cond, StartLoc, LParenLoc, EndLoc),
+        OMPClauseWithPreInit(this) {
+    setPreInitStmt(HelperCond, CaptureRegion);
+  }
+
+  /// Build an empty clause.
+  OMPGraphResetClause() : OMPOneStmtClause(), OMPClauseWithPreInit(this) {}
+
+  /// Returns condition.
+  Expr *getCondition() const { return getStmtAs<Expr>(); }
+
+  child_range used_children();
+  const_child_range used_children() const {
+    auto Children = const_cast<OMPGraphResetClause *>(this)->used_children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+};
+
 /// This represents clause 'has_device_ptr' in the '#pragma omp ...'
 /// directives.
 ///
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index ce6ad723191e0..c327617c21b74 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3255,6 +3255,9 @@ DEF_TRAVERSE_STMT(OMPBarrierDirective,
 DEF_TRAVERSE_STMT(OMPTaskwaitDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(OMPTaskgraphDirective,
+                  { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
 DEF_TRAVERSE_STMT(OMPTaskgroupDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
@@ -3631,6 +3634,12 @@ bool RecursiveASTVisitor<Derived>::VisitOMPNowaitClause(OMPNowaitClause *C) {
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPReplayableClause(OMPReplayableClause *C) {
+  TRY_TO(TraverseStmt(C->getCondition()));
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPUntiedClause(OMPUntiedClause *) {
   return true;
@@ -4123,6 +4132,21 @@ bool RecursiveASTVisitor<Derived>::VisitOMPIsDevicePtrClause(
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPGraphIdClause(OMPGraphIdClause *C) {
+  TRY_TO(VisitOMPClauseWithPreInit(C));
+  TRY_TO(TraverseStmt(C->getCondition()));
+  return true;
+}
+
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPGraphResetClause(
+    OMPGraphResetClause *C) {
+  TRY_TO(VisitOMPClauseWithPreInit(C));
+  TRY_TO(TraverseStmt(C->getCondition()));
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPHasDeviceAddrClause(
     OMPHasDeviceAddrClause *C) {
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index bc6aeaa8d143c..be4d33c783800 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -2760,6 +2760,55 @@ class OMPTaskwaitDirective : public OMPExecutableDirective {
   }
 };
 
+/// This represents '#pragma omp taskgraph' directive.
+/// Available with OpenMP 6.0.
+///
+/// \code
+/// #pragma omp taskgraph
+/// \endcode
+///
+class OMPTaskgraphDirective final : public OMPExecutableDirective {
+  friend class ASTStmtReader;
+  friend class OMPExecutableDirective;
+  /// Build directive with the given start and end location.
+  ///
+  /// \param StartLoc Starting location of the directive kind.
+  /// \param EndLoc Ending location of the directive.
+  ///
+  OMPTaskgraphDirective(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPExecutableDirective(OMPTaskgraphDirectiveClass,
+                               llvm::omp::OMPD_taskgraph, StartLoc, EndLoc) {}
+
+  /// Build an empty directive.
+  ///
+  explicit OMPTaskgraphDirective()
+      : OMPExecutableDirective(OMPTaskgraphDirectiveClass,
+                               llvm::omp::OMPD_taskgraph, SourceLocation(),
+                               SourceLocation()) {}
+
+public:
+  /// Creates directive.
+  ///
+  /// \param C AST context.
+  /// \param StartLoc Starting location of the directive kind.
+  /// \param EndLoc Ending Location of the directive.
+  ///
+  static OMPTaskgraphDirective *
+  Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+         ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt);
+
+  /// Creates an empty directive.
+  ///
+  /// \param C AST context.
+  ///
+  static OMPTaskgraphDirective *CreateEmpty(const ASTContext &C,
+                                            unsigned NumClauses, EmptyShell);
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPTaskgraphDirectiveClass;
+  }
+};
+
 /// This represents '#pragma omp taskgroup' directive.
 ///
 /// \code
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index b196382025c95..19cb832782195 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -264,6 +264,7 @@ def OMPTaskDirective : StmtNode<OMPExecutableDirective>;
 def OMPTaskyieldDirective : StmtNode<OMPExecutableDirective>;
 def OMPBarrierDirective : StmtNode<OMPExecutableDirective>;
 def OMPTaskwaitDirective : StmtNode<OMPExecutableDirective>;
+def OMPTaskgraphDirective : StmtNode<OMPExecutableDirective>;
 def OMPTaskgroupDirective : StmtNode<OMPExecutableDirective>;
 def OMPFlushDirective : StmtNode<OMPExecutableDirective>;
 def OMPDepobjDirective : StmtNode<OMPExecutableDirective>;
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 7853f29f98c25..6901740a03df7 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -557,6 +557,10 @@ class SemaOpenMP : public SemaBase {
   /// Called on well-formed '\#pragma omp barrier'.
   StmtResult ActOnOpenMPBarrierDirective(SourceLocation StartLoc,
                                          SourceLocation EndLoc);
+  /// Called on well-formed '\#pragma omp taskgraph'.
+  StmtResult ActOnOpenMPTaskgraphDirective(ArrayRef<OMPClause *> Clauses,
+                                           Stmt *AStmt, SourceLocation StartLoc,
+                                           SourceLocation EndLoc);
   /// Called on well-formed '\#pragma omp taskwait'.
   StmtResult ActOnOpenMPTaskwaitDirective(ArrayRef<OMPClause *> Clauses,
                                           SourceLocation StartLoc,
@@ -939,6 +943,15 @@ class SemaOpenMP : public SemaBase {
   ActOnOpenMPOrderedClause(SourceLocation StartLoc, SourceLocation EndLoc,
                            SourceLocation LParenLoc = SourceLocation(),
                            Expr *NumForLoops = nullptr);
+  /// Called on well-formed 'graph_id' clause.
+  OMPClause *ActOnOpenMPGraphIdClause(Expr *Condition, SourceLocation StartLoc,
+                                      SourceLocation LParenLoc,
+                                      SourceLocation EndLoc);
+  /// Called on well-formed 'graph_reset' clause.
+  OMPClause *ActOnOpenMPGraphResetClause(Expr *Condition,
+                                         SourceLocation StartLoc,
+                                         SourceLocation LParenLoc,
+                                         SourceLocation EndLoc);
   /// Called on well-formed 'grainsize' clause.
   OMPClause *ActOnOpenMPGrainsizeClause(OpenMPGrainsizeClauseModifier Modifier,
                                         Expr *Size, SourceLocation StartLoc,
@@ -1148,6 +1161,12 @@ class SemaOpenMP : public SemaBase {
   OMPClause *ActOnOpenMPSelfMapsClause(SourceLocation StartLoc,
                                        SourceLocation EndLoc);
 
+  /// Called on well-formed 'replayable' clause.
+  OMPClause *ActOnOpenMPReplayableClause(SourceLocation StartLoc,
+                                         SourceLocation EndLoc,
+                                         SourceLocation LParenLoc,
+                                         Expr *Condition);
+
   /// Called on well-formed 'at' clause.
   OMPClause *ActOnOpenMPAtClause(OpenMPAtClauseKind Kind,
                                  SourceLocation KindLoc,
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 5db0b08f877ce..a40f9a6eba4fa 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1981,6 +1981,7 @@ enum StmtCode {
   STMT_OMP_ERROR_DIRECTIVE,
   STMT_OMP_BARRIER_DIRECTIVE,
   STMT_OMP_TASKWAIT_DIRECTIVE,
+  STMT_OMP_TASKGRAPH_DIRECTIVE,
   STMT_OMP_FLUSH_DIRECTIVE,
   STMT_OMP_DEPOBJ_DIRECTIVE,
   STMT_OMP_SCAN_DIRECTIVE,
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index d4826c3c6edca..a2a04f494fc32 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -91,6 +91,10 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
     return static_cast<const OMPDeviceClause *>(C);
   case OMPC_grainsize:
     return static_cast<const OMPGrainsizeClause *>(C);
+  case OMPC_graph_id:
+    return static_cast<const OMPGraphIdClause *>(C);
+  case OMPC_graph_reset:
+    return static_cast<const OMPGraphResetClause *>(C);
   case OMPC_num_tasks:
     return static_cast<const OMPNumTasksClause *>(C);
   case OMPC_final:
@@ -252,6 +256,8 @@ const OMPClauseWithPostUpdate *OMPClauseWithPostUpdate::get(const OMPClause *C)
   case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
+  case OMPC_graph_id:
+  case OMPC_graph_reset:
   case OMPC_nogroup:
   case OMPC_num_tasks:
   case OMPC_hint:
@@ -320,12 +326,30 @@ OMPClause::child_range OMPNowaitClause::used_children() {
   return children();
 }
 
+OMPClause::child_range OMPReplayableClause::used_children() {
+  if (Condition)
+    return child_range(&Condition, &Condition + 1);
+  return children();
+}
+
 OMPClause::child_range OMPGrainsizeClause::used_children() {
   if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt()))
     return child_range(C, C + 1);
   return child_range(&Grainsize, &Grainsize + 1);
 }
 
+OMPClause::child_range OMPGraphIdClause::used_children() {
+  if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt()))
+    return child_range(C, C + 1);
+  return children();
+}
+
+OMPClause::child_range OMPGraphResetClause::used_children() {
+  if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt()))
+    return child_range(C, C + 1);
+  return children();
+}
+
 OMPClause::child_range OMPNumTasksClause::used_children() {
   if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt()))
     return child_range(C, C + 1);
@@ -2158,6 +2182,15 @@ void OMPClausePrinter::VisitOMPNowaitClause(OMPNowaitClause *Node) {
   }
 }
 
+void OMPClausePrinter::VisitOMPReplayableClause(OMPReplayableClause *Node) {
+  OS << "replayable";
+  if (auto *Cond = Node->getCondition()) {
+    OS << "(";
+    Cond->printPretty(OS, nullptr, Policy, 0);
+    OS << ")";
+  }
+}
+
 void OMPClausePrinter::VisitOMPUntiedClause(OMPUntiedClause *) {
   OS << "untied";
 }
@@ -2334,6 +2367,24 @@ void OMPClausePrinter::VisitOMPGrainsizeClause(OMPGrainsizeClause *Node) {
   OS << ")";
 }
 
+void OMPClausePrinter::VisitOMPGraphIdClause(OMPGraphIdClause *Node) {
+  OS << "graph_id";
+  if (Expr *E = Node->getCondition()) {
+    OS << "(";
+    E->printPretty(OS, nullptr, Policy, 0);
+    OS << ")";
+  }
+}
+
+void OMPClausePrinter::VisitOMPGraphResetClause(OMPGraphResetClause *Node) {
+  OS << "graph_reset";
+  if (Expr *E = Node->getCondition()) {
+    OS << "(";
+    E->printPretty(OS, nullptr, Policy, 0);
+    OS << ")";
+  }
+}
+
 void OMPClausePrinter::VisitOMPNumTasksClause(OMPNumTasksClause *Node) {
   OS << "num_tasks(";
   OpenMPNumTasksClauseModifier Modifier = Node->getModifier();
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index a5b0cd3786a28..41effd494524c 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -945,6 +945,21 @@ OMPTaskwaitDirective *OMPTaskwaitDirective::CreateEmpty(const ASTContext &C,
   return createEmptyDirective<OMPTaskwaitDirective>(C, NumClauses);
 }
 
+OMPTaskgraphDirective *OMPTaskgraphDirective::Create(
+    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt) {
+  auto *Dir = createDirective<OMPTaskgraphDirective>(
+      C, Clauses, AssociatedStmt, /*NumChildren=*/1, StartLoc, EndLoc);
+  return Dir;
+}
+
+OMPTaskgraphDirective *OMPTaskgraphDirective::CreateEmpty(const ASTContext &C,
+                                                          unsigned NumClauses,
+                                                          EmptyShell) {
+  return createEmptyDirective<OMPTaskgraphDirective>(
+      C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/1);
+}
+
 OMPTaskgroupDirective *OMPTaskgroupDirective::Create(
     const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
     ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *ReductionRef) {
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 4d364fdcd5502..f82f83613dc4d 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -904,6 +904,11 @@ void StmtPrinter::VisitOMPAssumeDirective(OMPAssumeDirective *Node) {
   PrintOMPExecutableDirective(Node);
 }
 
+void StmtPrinter::VisitOMPTaskgraphDirective(OMPTaskgraphDirective *Node) {
+  Indent() << "#pragma omp taskgraph";
+  PrintOMPExecutableDirective(Node);
+}
+
 void StmtPrinter::VisitOMPErrorDirective(OMPErrorDirective *Node) {
   Indent() << "#pragma omp error";
   PrintOMPExecutableDirective(Node);
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index dc7fd352a67b2..11f8f96bfa16b 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -600,6 +600,11 @@ void OMPClauseProfiler::VisitOMPNowaitClause(const OMPNowaitClause *C) {
     Profiler->VisitStmt(C->getCondition());
 }
 
+void OMPClauseProfiler::VisitOMPReplayableClause(const OMPReplayableClause *C) {
+  if (C->getCondition())
+    Profiler->VisitStmt(C->getCondition());
+}
+
 void OMPClauseProfiler::VisitOMPUntiedClause(const OMPUntiedClause *) {}
 
 void OMPClauseProfiler::VisitOMPMergeableClause(const OMPMergeableClause *) {}
@@ -910,6 +915,16 @@ void OMPClauseProfiler::VisitOMPGrainsizeClause(const ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/188765
_______________________________________________
cfe-commits mailing list
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits