This revision was automatically updated to reflect the committed changes.
Closed by commit rG2130117f92e5: [Clang][OpenMP] Allow loop-transformations with template parameters. (authored by Meinersbur).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D111124/new/

https://reviews.llvm.org/D111124

Files:
  clang/include/clang/AST/StmtOpenMP.h
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Serialization/ASTReaderStmt.cpp
  clang/lib/Serialization/ASTWriterStmt.cpp
  clang/test/OpenMP/tile_ast_print.cpp
  clang/test/OpenMP/unroll_ast_print.cpp

Index: clang/test/OpenMP/unroll_ast_print.cpp
===================================================================
--- clang/test/OpenMP/unroll_ast_print.cpp
+++ clang/test/OpenMP/unroll_ast_print.cpp
@@ -124,4 +124,26 @@
   unroll_templated<int,0,1024,1,4>();
 }
 
+// Partial unrolling whose unroll factor is a template parameter.
+// PRINT-LABEL: template <int Factor> void unroll_templated_factor(int start, int stop, int step) {
+// DUMP-LABEL:  FunctionTemplateDecl {{.*}} unroll_templated_factor
+template <int Factor>
+void unroll_templated_factor(int start, int stop, int step) {
+  // PRINT: #pragma omp unroll partial(Factor)
+  // DUMP:      OMPUnrollDirective
+  // DUMP-NEXT: OMPPartialClause
+  // DUMP-NEXT:   DeclRefExpr {{.*}} 'Factor' 'int'
+  #pragma omp unroll partial(Factor)
+    // PRINT-NEXT: for (int i = start; i < stop; i += step)
+    // DUMP-NEXT:  ForStmt
+    for (int i = start; i < stop; i += step)
+      // PRINT-NEXT: body(i);
+      // DUMP:  CallExpr
+      body(i);
+}
+void unroll_template_factor() { // Instantiates with Factor = 4.
+  unroll_templated_factor<4>(0, 42, 2);
+}
+
+
 #endif
Index: clang/test/OpenMP/tile_ast_print.cpp
===================================================================
--- clang/test/OpenMP/tile_ast_print.cpp
+++ clang/test/OpenMP/tile_ast_print.cpp
@@ -162,4 +162,25 @@
 }
 
 
+// PRINT-LABEL: template <int Tile> void foo7(int start, int stop, int step) {
+// DUMP-LABEL: FunctionTemplateDecl {{.*}} foo7
+template <int Tile> // The tile size is a template parameter.
+void foo7(int start, int stop, int step) {
+  // PRINT: #pragma omp tile sizes(Tile)
+  // DUMP:      OMPTileDirective
+  // DUMP-NEXT:   OMPSizesClause
+  // DUMP-NEXT:     DeclRefExpr {{.*}} 'Tile' 'int'
+  #pragma omp tile sizes(Tile)
+    // PRINT-NEXT:  for (int i = start; i < stop; i += step)
+    // DUMP-NEXT: ForStmt
+    for (int i = start; i < stop; i += step)
+      // PRINT-NEXT: body(i);
+      // DUMP:  CallExpr
+      body(i);
+}
+void tfoo7() { // Instantiates foo7 with Tile = 5.
+  foo7<5>(0, 42, 2);
+}
+
+
 #endif
Index: clang/lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTWriterStmt.cpp
+++ clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2226,6 +2226,7 @@
 void ASTStmtWriter::VisitOMPLoopTransformationDirective(
     OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
+  Record.writeUInt32(D->getNumGeneratedLoops()); // Read by ASTStmtReader.
 }
 
 void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
Index: clang/lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTReaderStmt.cpp
+++ clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2327,6 +2327,7 @@
 void ASTStmtReader::VisitOMPLoopTransformationDirective(
     OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
+  D->setNumGeneratedLoops(Record.readUInt32()); // Written by ASTStmtWriter.
 }
 
 void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -12919,10 +12919,12 @@
                                   Body, OriginalInits))
     return StmtError();
 
+  unsigned NumGeneratedLoops = PartialClause ? 1 : 0; // partial() => 1 loop
+
   // Delay unrolling to when template is completely instantiated.
   if (CurContext->isDependentContext())
     return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                      nullptr, nullptr);
+                                      NumGeneratedLoops, nullptr, nullptr);
 
   OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
 
@@ -12941,9 +12943,9 @@
   // The generated loop may only be passed to other loop-associated directive
   // when a partial clause is specified. Without the requirement it is
   // sufficient to generate loop unroll metadata at code-generation.
-  if (!PartialClause)
+  if (NumGeneratedLoops == 0)
     return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                      nullptr, nullptr);
+                                      NumGeneratedLoops, nullptr, nullptr);
 
   // Otherwise, we need to provide a de-sugared/transformed AST that can be
   // associated with another loop directive.
@@ -13164,7 +13166,8 @@
               LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
 
   return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    OuterFor, buildPreInits(Context, PreInits));
+                                    NumGeneratedLoops, OuterFor,
+                                    buildPreInits(Context, PreInits));
 }
 
 OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr,
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -138,9 +138,14 @@
 
       Stmt *TransformedStmt = Dir->getTransformedStmt();
       if (!TransformedStmt) {
-        // May happen if the loop transformation does not result in a
-        // generated loop (such as full unrolling).
-        break;
+        if (Dir->getNumGeneratedLoops() == 0) {
+          // May happen if the loop transformation does not result in a
+          // generated loop (such as full unrolling).
+          break;
+        }
+        // The construct does generate loops, but they may not have been
+        // built yet due to being in a dependent context.
+        return true;
       }
 
       CurStmt = TransformedStmt;
@@ -419,10 +424,13 @@
 OMPUnrollDirective *
 OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc,
                            SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
-                           Stmt *AssociatedStmt, Stmt *TransformedStmt,
-                           Stmt *PreInits) {
+                           Stmt *AssociatedStmt, unsigned NumGeneratedLoops,
+                           Stmt *TransformedStmt, Stmt *PreInits) {
+  assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop");
+
   auto *Dir = createDirective<OMPUnrollDirective>(
       C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
+  Dir->setNumGeneratedLoops(NumGeneratedLoops);
   Dir->setTransformedStmt(TransformedStmt);
   Dir->setPreInits(PreInits);
   return Dir;
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -959,6 +959,9 @@
 class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
   friend class ASTStmtReader;
 
+  /// Number of loops generated by this loop transformation.
+  unsigned NumGeneratedLoops = 0;
+
 protected:
   explicit OMPLoopTransformationDirective(StmtClass SC,
                                           OpenMPDirectiveKind Kind,
@@ -967,10 +970,16 @@
                                           unsigned NumAssociatedLoops)
       : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
 
+  /// Set the number of loops generated by this loop transformation.
+  void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; }
+
 public:
   /// Return the number of associated (consumed) loops.
   unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
 
+  /// Return the number of loops generated by this loop transformation.
+  unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; }
+
   /// Get the de-sugared statements after after the loop transformation.
   ///
   /// Might be nullptr if either the directive generates no loops and is handled
@@ -5058,7 +5067,10 @@
                             unsigned NumLoops)
       : OMPLoopTransformationDirective(OMPTileDirectiveClass,
                                        llvm::omp::OMPD_tile, StartLoc, EndLoc,
-                                       NumLoops) {}
+                                       NumLoops) {
+    // Tiling N loops yields N floor loops plus N intra-tile loops.
+    setNumGeneratedLoops(2 * NumLoops);
+  }
 
   void setPreInits(Stmt *PreInits) {
     Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5163,7 +5175,7 @@
   static OMPUnrollDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
-         Stmt *TransformedStmt, Stmt *PreInits);
+         unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits);
 
   /// Build an empty '#pragma omp unroll' AST node for deserialization.
   ///
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to