This revision was automatically updated to reflect the committed changes.
Closed by commit rGf37e8b0b831e: [Clang][OpenMP] Infix OMPLoopTransformationDirective abstract class. NFC. (authored by Meinersbur).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D111119/new/

https://reviews.llvm.org/D111119

Files:
  clang/include/clang/AST/StmtOpenMP.h
  clang/include/clang/Basic/StmtNodes.td
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/AST/StmtProfile.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Serialization/ASTReaderStmt.cpp
  clang/lib/Serialization/ASTWriterStmt.cpp
  clang/tools/libclang/CIndex.cpp

Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -2046,6 +2046,8 @@
   void VisitOMPLoopDirective(const OMPLoopDirective *D);
   void VisitOMPParallelDirective(const OMPParallelDirective *D);
   void VisitOMPSimdDirective(const OMPSimdDirective *D);
+  void
+  VisitOMPLoopTransformationDirective(const OMPLoopTransformationDirective *D);
   void VisitOMPTileDirective(const OMPTileDirective *D);
   void VisitOMPUnrollDirective(const OMPUnrollDirective *D);
   void VisitOMPForDirective(const OMPForDirective *D);
@@ -2901,12 +2903,17 @@
   VisitOMPLoopDirective(D);
 }
 
-void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) {
+void EnqueueVisitor::VisitOMPLoopTransformationDirective(
+    const OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
 }
 
+void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) {
+  VisitOMPLoopTransformationDirective(D);
+}
+
 void EnqueueVisitor::VisitOMPUnrollDirective(const OMPUnrollDirective *D) {
-  VisitOMPLoopBasedDirective(D);
+  VisitOMPLoopTransformationDirective(D);
 }
 
 void EnqueueVisitor::VisitOMPForDirective(const OMPForDirective *D) {
Index: clang/lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTWriterStmt.cpp
+++ clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2223,13 +2223,18 @@
   Code = serialization::STMT_OMP_SIMD_DIRECTIVE;
 }
 
-void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
+void ASTStmtWriter::VisitOMPLoopTransformationDirective(
+    OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
+}
+
+void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
+  VisitOMPLoopTransformationDirective(D);
   Code = serialization::STMT_OMP_TILE_DIRECTIVE;
 }
 
 void ASTStmtWriter::VisitOMPUnrollDirective(OMPUnrollDirective *D) {
-  VisitOMPLoopBasedDirective(D);
+  VisitOMPLoopTransformationDirective(D);
   Code = serialization::STMT_OMP_UNROLL_DIRECTIVE;
 }
 
Index: clang/lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTReaderStmt.cpp
+++ clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2324,12 +2324,17 @@
   VisitOMPLoopDirective(D);
 }
 
-void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
+void ASTStmtReader::VisitOMPLoopTransformationDirective(
+    OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
 }
 
+void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
+  VisitOMPLoopTransformationDirective(D);
+}
+
 void ASTStmtReader::VisitOMPUnrollDirective(OMPUnrollDirective *D) {
-  VisitOMPLoopBasedDirective(D);
+  VisitOMPLoopTransformationDirective(D);
 }
 
 void ASTStmtReader::VisitOMPForDirective(OMPForDirective *D) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -3823,13 +3823,8 @@
     VisitSubCaptures(S);
   }
 
-  void VisitOMPTileDirective(OMPTileDirective *S) {
-    // #pragma omp tile does not introduce data sharing.
-    VisitStmt(S);
-  }
-
-  void VisitOMPUnrollDirective(OMPUnrollDirective *S) {
-    // #pragma omp unroll does not introduce data sharing.
+  void VisitOMPLoopTransformationDirective(OMPLoopTransformationDirective *S) {
+    // Loop transformation directives do not introduce data sharing.
     VisitStmt(S);
   }
 
@@ -9050,15 +9045,8 @@
             }
             return false;
           },
-          [&SemaRef, &Captures](OMPLoopBasedDirective *Transform) {
-            Stmt *DependentPreInits;
-            if (auto *Dir = dyn_cast<OMPTileDirective>(Transform)) {
-              DependentPreInits = Dir->getPreInits();
-            } else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform)) {
-              DependentPreInits = Dir->getPreInits();
-            } else {
-              llvm_unreachable("Unexpected loop transformation");
-            }
+          [&SemaRef, &Captures](OMPLoopTransformationDirective *Transform) {
+            Stmt *DependentPreInits = Transform->getPreInits();
             if (!DependentPreInits)
               return;
             for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) {
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -1829,9 +1829,7 @@
     return;
   }
   if (SimplifiedS == NextLoop) {
-    if (auto *Dir = dyn_cast<OMPTileDirective>(SimplifiedS))
-      SimplifiedS = Dir->getTransformedStmt();
-    if (auto *Dir = dyn_cast<OMPUnrollDirective>(SimplifiedS))
+    if (auto *Dir = dyn_cast<OMPLoopTransformationDirective>(SimplifiedS))
       SimplifiedS = Dir->getTransformedStmt();
     if (const auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(SimplifiedS))
       SimplifiedS = CanonLoop->getLoopStmt();
Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -915,12 +915,17 @@
   VisitOMPLoopDirective(S);
 }
 
-void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) {
+void StmtProfiler::VisitOMPLoopTransformationDirective(
+    const OMPLoopTransformationDirective *S) {
   VisitOMPLoopBasedDirective(S);
 }
 
+void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) {
+  VisitOMPLoopTransformationDirective(S);
+}
+
 void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) {
-  VisitOMPLoopBasedDirective(S);
+  VisitOMPLoopTransformationDirective(S);
 }
 
 void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) {
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -125,28 +125,25 @@
 bool OMPLoopBasedDirective::doForAllLoops(
     Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops,
     llvm::function_ref<bool(unsigned, Stmt *)> Callback,
-    llvm::function_ref<void(OMPLoopBasedDirective *)>
+    llvm::function_ref<void(OMPLoopTransformationDirective *)>
         OnTransformationCallback) {
   CurStmt = CurStmt->IgnoreContainers();
   for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) {
     while (true) {
-      auto *OrigStmt = CurStmt;
-      if (auto *Dir = dyn_cast<OMPTileDirective>(OrigStmt)) {
-        OnTransformationCallback(Dir);
-        CurStmt = Dir->getTransformedStmt();
-      } else if (auto *Dir = dyn_cast<OMPUnrollDirective>(OrigStmt)) {
-        OnTransformationCallback(Dir);
-        CurStmt = Dir->getTransformedStmt();
-      } else {
+      auto *Dir = dyn_cast<OMPLoopTransformationDirective>(CurStmt);
+      if (!Dir)
         break;
-      }
 
-      if (!CurStmt) {
-        // May happen if the loop transformation does not result in a generated
-        // loop (such as full unrolling).
-        CurStmt = OrigStmt;
+      OnTransformationCallback(Dir);
+
+      Stmt *TransformedStmt = Dir->getTransformedStmt();
+      if (!TransformedStmt) {
+        // May happen if the loop transformation does not result in a
+        // generated loop (such as full unrolling).
         break;
       }
+
+      CurStmt = TransformedStmt;
     }
     if (auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(CurStmt))
       CurStmt = CanonLoop->getLoopStmt();
@@ -363,6 +360,32 @@
   return Dir;
 }
 
+Stmt *OMPLoopTransformationDirective::getTransformedStmt() const {
+  switch (getStmtClass()) {
+#define STMT(CLASS, PARENT)
+#define ABSTRACT_STMT(CLASS)
+#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT)                          \
+  case Stmt::CLASS##Class:                                                     \
+    return static_cast<const CLASS *>(this)->getTransformedStmt();
+#include "clang/AST/StmtNodes.inc"
+  default:
+    llvm_unreachable("Not a loop transformation");
+  }
+}
+
+Stmt *OMPLoopTransformationDirective::getPreInits() const {
+  switch (getStmtClass()) {
+#define STMT(CLASS, PARENT)
+#define ABSTRACT_STMT(CLASS)
+#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT)                          \
+  case Stmt::CLASS##Class:                                                     \
+    return static_cast<const CLASS *>(this)->getPreInits();
+#include "clang/AST/StmtNodes.inc"
+  default:
+    llvm_unreachable("Not a loop transformation");
+  }
+}
+
 OMPForDirective *OMPForDirective::CreateEmpty(const ASTContext &C,
                                               unsigned NumClauses,
                                               unsigned CollapsedNum,
Index: clang/include/clang/Basic/StmtNodes.td
===================================================================
--- clang/include/clang/Basic/StmtNodes.td
+++ clang/include/clang/Basic/StmtNodes.td
@@ -224,8 +224,9 @@
 def OMPLoopDirective : StmtNode<OMPLoopBasedDirective, 1>;
 def OMPParallelDirective : StmtNode<OMPExecutableDirective>;
 def OMPSimdDirective : StmtNode<OMPLoopDirective>;
-def OMPTileDirective : StmtNode<OMPLoopBasedDirective>;
-def OMPUnrollDirective : StmtNode<OMPLoopBasedDirective>;
+def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>;
+def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>;
+def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
 def OMPForDirective : StmtNode<OMPLoopDirective>;
 def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -889,22 +889,23 @@
 
   /// Calls the specified callback function for all the loops in \p CurStmt,
   /// from the outermost to the innermost.
-  static bool doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
-                            unsigned NumLoops,
-                            llvm::function_ref<bool(unsigned, Stmt *)> Callback,
-                            llvm::function_ref<void(OMPLoopBasedDirective *)>
-                                OnTransformationCallback);
+  static bool
+  doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
+                unsigned NumLoops,
+                llvm::function_ref<bool(unsigned, Stmt *)> Callback,
+                llvm::function_ref<void(OMPLoopTransformationDirective *)>
+                    OnTransformationCallback);
   static bool
   doForAllLoops(const Stmt *CurStmt, bool TryImperfectlyNestedLoops,
                 unsigned NumLoops,
                 llvm::function_ref<bool(unsigned, const Stmt *)> Callback,
-                llvm::function_ref<void(const OMPLoopBasedDirective *)>
+                llvm::function_ref<void(const OMPLoopTransformationDirective *)>
                     OnTransformationCallback) {
     auto &&NewCallback = [Callback](unsigned Cnt, Stmt *CurStmt) {
       return Callback(Cnt, CurStmt);
     };
     auto &&NewTransformCb =
-        [OnTransformationCallback](OMPLoopBasedDirective *A) {
+        [OnTransformationCallback](OMPLoopTransformationDirective *A) {
           OnTransformationCallback(A);
         };
     return doForAllLoops(const_cast<Stmt *>(CurStmt), TryImperfectlyNestedLoops,
@@ -917,7 +918,7 @@
   doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
                 unsigned NumLoops,
                 llvm::function_ref<bool(unsigned, Stmt *)> Callback) {
-    auto &&TransformCb = [](OMPLoopBasedDirective *) {};
+    auto &&TransformCb = [](OMPLoopTransformationDirective *) {};
     return doForAllLoops(CurStmt, TryImperfectlyNestedLoops, NumLoops, Callback,
                          TransformCb);
   }
@@ -954,6 +955,38 @@
   }
 };
 
+/// The base class for all loop transformation directives.
+class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
+  friend class ASTStmtReader;
+
+protected:
+  explicit OMPLoopTransformationDirective(StmtClass SC,
+                                          OpenMPDirectiveKind Kind,
+                                          SourceLocation StartLoc,
+                                          SourceLocation EndLoc,
+                                          unsigned NumAssociatedLoops)
+      : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
+
+public:
+  /// Return the number of associated (consumed) loops.
+  unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
+
+  /// Get the de-sugared statements after the loop transformation.
+  ///
+  /// Might be nullptr if either the directive generates no loops and is handled
+  /// directly in CodeGen, or resolving a template-dependence context is
+  /// required.
+  Stmt *getTransformedStmt() const;
+
+  /// Return preinits statement.
+  Stmt *getPreInits() const;
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPTileDirectiveClass ||
+           T->getStmtClass() == OMPUnrollDirectiveClass;
+  }
+};
+
 /// This is a common base class for loop directives ('omp simd', 'omp
 /// for', 'omp for simd' etc.). It is responsible for the loop code generation.
 ///
@@ -5011,7 +5044,7 @@
 };
 
 /// This represents the '#pragma omp tile' loop transformation directive.
-class OMPTileDirective final : public OMPLoopBasedDirective {
+class OMPTileDirective final : public OMPLoopTransformationDirective {
   friend class ASTStmtReader;
   friend class OMPExecutableDirective;
 
@@ -5023,8 +5056,9 @@
 
   explicit OMPTileDirective(SourceLocation StartLoc, SourceLocation EndLoc,
                             unsigned NumLoops)
-      : OMPLoopBasedDirective(OMPTileDirectiveClass, llvm::omp::OMPD_tile,
-                              StartLoc, EndLoc, NumLoops) {}
+      : OMPLoopTransformationDirective(OMPTileDirectiveClass,
+                                       llvm::omp::OMPD_tile, StartLoc, EndLoc,
+                                       NumLoops) {}
 
   void setPreInits(Stmt *PreInits) {
     Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5061,8 +5095,6 @@
   static OMPTileDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
                                        unsigned NumLoops);
 
-  unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
-
   /// Gets/sets the associated loops after tiling.
   ///
   /// This is in de-sugared format stored as a CompoundStmt.
@@ -5092,7 +5124,7 @@
 /// #pragma omp unroll
 /// for (int i = 0; i < 64; ++i)
 /// \endcode
-class OMPUnrollDirective final : public OMPLoopBasedDirective {
+class OMPUnrollDirective final : public OMPLoopTransformationDirective {
   friend class ASTStmtReader;
   friend class OMPExecutableDirective;
 
@@ -5103,8 +5135,9 @@
   };
 
   explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc)
-      : OMPLoopBasedDirective(OMPUnrollDirectiveClass, llvm::omp::OMPD_unroll,
-                              StartLoc, EndLoc, 1) {}
+      : OMPLoopTransformationDirective(OMPUnrollDirectiveClass,
+                                       llvm::omp::OMPD_unroll, StartLoc, EndLoc,
+                                       1) {}
 
   /// Set the pre-init statements.
   void setPreInits(Stmt *PreInits) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to