Meinersbur created this revision.
Meinersbur added reviewers: hfinkel, kbarton, SjoerdMeijer, aaron.ballman, 
ABataev, fhahn, hsaito, hans, greened, dmgreen, Ayal, asavonic, rtrieu, dorit, 
rsmith, tyler.nowicki, jdoerfert.
Herald added subscribers: cfe-commits, zzheng.
Herald added a project: clang.
Meinersbur removed subscribers: zzheng, llvm-commits.
Meinersbur added a subscriber: zzheng.

(De-)serialization of #pragma clang transform AST nodes and clauses.

For a full description, see D69088 <https://reviews.llvm.org/D69088>.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D70572

Files:
  clang/include/clang/Serialization/ASTBitCodes.h
  clang/include/clang/Serialization/ASTReader.h
  clang/include/clang/Serialization/ASTWriter.h
  clang/lib/AST/StmtProfile.cpp
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTReaderStmt.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/lib/Serialization/ASTWriterStmt.cpp
  clang/test/PCH/transform-interleave.cpp
  clang/test/PCH/transform-unroll.cpp
  clang/test/PCH/transform-unrollandjam.cpp
  clang/test/PCH/transform-vectorize.cpp

Index: clang/test/PCH/transform-vectorize.cpp
===================================================================
--- /dev/null
+++ clang/test/PCH/transform-vectorize.cpp
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -emit-pch -o %t.pch %s
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -include-pch %t.pch %s -ast-dump-all -o - | FileCheck %s --dump-input=fail -vv
+
+#ifndef HEADER
+#define HEADER
+
+void vectorize_heuristic(int n) {
+#pragma clang transform vectorize
+  for (int i = 0; i < n; i+=1)
+    ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported vectorize_heuristic
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: ForStmt
+
+
+void vectorize_width(int n) {
+#pragma clang transform vectorize width(4)
+  for (int i = 0; i < n; i+=1)
+    ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported vectorize_width
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: WidthClause
+// CHECK-NEXT:   IntegerLiteral {{.*}} 'int' 4
+// CHECK-NEXT: ForStmt
+
+#endif /* HEADER */
Index: clang/test/PCH/transform-unrollandjam.cpp
===================================================================
--- /dev/null
+++ clang/test/PCH/transform-unrollandjam.cpp
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -emit-pch -o %t.pch %s
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -include-pch %t.pch %s -ast-dump-all -o - | FileCheck %s --dump-input=fail -vv
+
+#ifndef HEADER
+#define HEADER
+
+void  unrollandjam_heuristic(int n) {
+#pragma clang transform unrollandjam
+  for (int i = 0; i < n; i+=1)
+    for (int j = 0; j < n; j+=1)
+      ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported unrollandjam_heuristic
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: ForStmt
+
+
+void unrollandjam_partial(int n) {
+#pragma clang transform unrollandjam partial(4)
+  for (int i = 0; i < n; i+=1)
+    for (int j = 0; j < n; j+=1)
+      ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported unrollandjam_partial
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: PartialClause
+// CHECK-NEXT:   IntegerLiteral {{.*}} 'int' 4
+// CHECK-NEXT: ForStmt
+
+#endif /* HEADER */
Index: clang/test/PCH/transform-unroll.cpp
===================================================================
--- /dev/null
+++ clang/test/PCH/transform-unroll.cpp
@@ -0,0 +1,85 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -emit-pch -o %t.pch %s
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -include-pch %t.pch %s -ast-dump-all -o - | FileCheck %s --dump-input=fail -vv
+
+#ifndef HEADER
+#define HEADER
+
+void  unroll_heuristic(int n) {
+#pragma clang transform unroll
+  for (int i = 0; i < 4; i+=1)
+    ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported unroll_heuristic
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: ForStmt
+
+
+void unroll_full(int n) {
+#pragma clang transform unroll full
+  for (int i = 0; i < 4; i+=1)
+    ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported unroll_full
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: FullClause
+// CHECK-NEXT: ForStmt
+
+
+void unroll_partial(int n) {
+#pragma clang transform unroll partial(4)
+  for (int i = 0; i < n; i+=1)
+    ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported unroll_partial
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: PartialClause
+// CHECK-NEXT:   IntegerLiteral {{.*}} 'int' 4
+// CHECK-NEXT: ForStmt
+
+
+template<int FACTOR>
+void unroll_template_function(int n) {
+#pragma clang transform unroll partial(FACTOR)
+  for (int i = 0; i < n; i+=1)
+    ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported unroll_template_function
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: PartialClause
+// CHECK-NEXT:   DeclRefExpr {{.*}} 'FACTOR' 'int'
+// CHECK-NEXT: ForStmt
+
+
+template void unroll_template_function<5>(int);
+// CHECK-LABEL: FunctionDecl {{.*}} imported unroll_template_function
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: PartialClause
+// CHECK-NEXT:   SubstNonTypeTemplateParmExpr
+// CHECK-NEXT:   IntegerLiteral {{.*}} 'int' 5
+// CHECK-NEXT: ForStmt
+
+
+template<int FACTOR>
+struct Unroll {
+  void unroll_template_method(int n) {
+#pragma clang transform unroll partial(FACTOR)
+    for (int i = 0; i < n; i+=1)
+      ;
+  }
+};
+// CHECK-LABEL: CXXMethodDecl {{.*}} imported unroll_template_method
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: PartialClause
+// CHECK-NEXT:   DeclRefExpr {{.*}} 'FACTOR' 'int'
+// CHECK-NEXT: ForStmt
+
+
+template struct Unroll<6>;
+// CHECK-LABEL: CXXMethodDecl {{.*}} imported unroll_template_method
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: PartialClause
+// CHECK-NEXT:   SubstNonTypeTemplateParmExpr
+// CHECK-NEXT:   IntegerLiteral {{.*}} 'int' 6
+// CHECK-NEXT: ForStmt
+
+#endif /* HEADER */
Index: clang/test/PCH/transform-interleave.cpp
===================================================================
--- /dev/null
+++ clang/test/PCH/transform-interleave.cpp
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -emit-pch -o %t.pch %s
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -include-pch %t.pch %s -ast-dump-all -o - | FileCheck %s --dump-input=fail -vv
+
+#ifndef HEADER
+#define HEADER
+
+void interleave_heuristic(int n) {
+#pragma clang transform interleave
+  for (int i = 0; i < n; i+=1)
+    ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported interleave_heuristic
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: ForStmt
+
+
+void interleave_factor(int n) {
+#pragma clang transform interleave factor(4)
+  for (int i = 0; i < n; i+=1)
+    ;
+}
+// CHECK-LABEL: FunctionDecl {{.*}} imported interleave_factor
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: FactorClause
+// CHECK-NEXT:   IntegerLiteral {{.*}} 'int' 4
+// CHECK-NEXT: ForStmt
+
+#endif /* HEADER */
Index: clang/lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTWriterStmt.cpp
+++ clang/lib/Serialization/ASTWriterStmt.cpp
@@ -1953,13 +1953,45 @@
   Record.AddSourceLocation(S->getLeaveLoc());
   Code = serialization::STMT_SEH_LEAVE;
 }
+
 //===----------------------------------------------------------------------===//
 // Transformation Directives.
 //===----------------------------------------------------------------------===//
 
 void ASTStmtWriter::VisitTransformExecutableDirective(
     TransformExecutableDirective *D) {
-  llvm_unreachable("not implemented");
+  Code = serialization::STMT_TRANSFORM_EXECUTABLE_DIRECTIVE;
+  VisitStmt(D);
+  Record.push_back(D->getNumClauses());
+
+  Record.AddSourceRange(D->getRange());
+  TransformClauseWriter ClauseWriter(Record);
+  for (auto C : D->clauses())
+    ClauseWriter.writeClause(C);
+
+#if 0
+  {
+    Record.push_back(C->getKind());
+    Record.AddSourceRange(C->getRange());
+    switch (C->getKind()) {
+    case TransformClause::UnknownKind:
+      llvm_unreachable("Cannot write unknown clause");
+    case TransformClause::FullKind:
+      break;
+    case TransformClause::FactorKind:
+      Record.AddStmt(static_cast<FactorClause *>(C)->getFactor());
+      break;
+    case TransformClause::WidthKind:
+      Record.AddStmt(static_cast<WidthClause *>(C)->getWidth());
+      break;
+    case TransformClause::PartialKind:
+            Record.AddStmt(static_cast<PartialClause *>(C)->getFactor());
+      break;
+    }
+  }
+#endif
+
+  Record.AddStmt(D->getAssociated());
 }
 
 //===----------------------------------------------------------------------===//
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -7127,3 +7127,29 @@
   Record.AddSourceLocation(C->getLParenLoc());
   Record.AddSourceLocation(C->getAtomicDefaultMemOrderKindKwLoc());
 }
+
+void TransformClauseWriter::writeClause(const TransformClause *C) {
+  Record.push_back(C->getKind());
+  Record.AddSourceRange(C->getRange());
+  Visit(C);
+}
+
+//===----------------------------------------------------------------------===//
+// TransformClause Serialization
+//===----------------------------------------------------------------------===//
+
+void TransformClauseWriter::VisitFullClause(const FullClause *C) {
+  // The full clause has no arguments.
+}
+
+void TransformClauseWriter::VisitPartialClause(const PartialClause *C) {
+  Record.AddStmt(C->getFactor());
+}
+
+void TransformClauseWriter::VisitWidthClause(const WidthClause *C) {
+  Record.AddStmt(C->getWidth());
+}
+
+void TransformClauseWriter::VisitFactorClause(const FactorClause *C) {
+  Record.AddStmt(C->getFactor());
+}
Index: clang/lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTReaderStmt.cpp
+++ clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2014,7 +2014,22 @@
 
 void ASTStmtReader::VisitTransformExecutableDirective(
     TransformExecutableDirective *D) {
-  llvm_unreachable("not implemented");
+  VisitStmt(D);
+  unsigned NumClauses = Record.readInt();
+  assert(D->getNumClauses() == NumClauses);
+  // The binary layout up to here is also assumed by
+  // ASTReader::ReadStmtFromStream and must be kept in-sync.
+
+  D->setRange(ReadSourceRange());
+
+  SmallVector<TransformClause *, 8> Clauses;
+  Clauses.reserve(NumClauses);
+  TransformClauseReader ClauseReader(Record);
+  for (unsigned i = 0; i < NumClauses; ++i)
+    Clauses.push_back(ClauseReader.readClause());
+  D->setClauses(Clauses);
+
+  D->setAssociated(Record.readSubStmt());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2495,6 +2510,10 @@
       return nullptr;
     }
     switch ((StmtCode)MaybeStmtCode.get()) {
+    default:
+      llvm_unreachable("Unexpected statement type");
+      break;
+
     case STMT_STOP:
       Finished = true;
       break;
@@ -2942,6 +2961,11 @@
                                               nullptr);
       break;
 
+    case STMT_TRANSFORM_EXECUTABLE_DIRECTIVE:
+      S = TransformExecutableDirective::createEmpty(
+          Context, Record[ASTStmtReader::NumStmtFields]);
+      break;
+
     case STMT_OMP_PARALLEL_DIRECTIVE:
       S =
         OMPParallelDirective::CreateEmpty(Context,
@@ -3560,7 +3584,6 @@
       unsigned numTemplateArgs = Record[ASTStmtReader::NumExprFields];
       S = ConceptSpecializationExpr::Create(Context, Empty, numTemplateArgs);
       break;
-      
     }
 
     // We hit a STMT_STOP, so we're done with this expression.
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -13244,3 +13244,55 @@
   }
   C->setComponents(Components, ListSizes);
 }
+
+//===----------------------------------------------------------------------===//
+// TransformClauseReader implementation
+//===----------------------------------------------------------------------===//
+
+TransformClause *TransformClauseReader::readClause() {
+  uint64_t Kind = Record.readInt();
+  SourceRange Range = Record.readSourceRange();
+
+  switch (Kind) {
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  case TransformClause::Kind::Name##Kind:                                      \
+    return read##Name##Clause(Range);
+#include "clang/AST/TransformClauseKinds.def"
+#if 0
+  case TransformClause::Kind::FullKind:
+    Clause = FullClause::createEmpty(Context);
+    break;
+  case TransformClause::Kind::PartialKind: {
+    auto* C = PartialClause::createEmpty(Context);
+    C->setFactor(Record.readExpr());
+  } break;
+      case TransformClause::Kind::WidthKind:
+    C = WidthClause::createEmpty(Context);
+    break;
+          case TransformClause::Kind::FactorKind:
+    C = FactorClause::createEmpty(Context);
+    break;
+#endif
+  default:
+    llvm_unreachable("Unknown transform clause kind");
+  }
+}
+
+FullClause *TransformClauseReader::readFullClause(SourceRange Range) {
+  return FullClause::create(Context, Range);
+}
+
+PartialClause *TransformClauseReader::readPartialClause(SourceRange Range) {
+  Expr *Factor = Record.readExpr();
+  return PartialClause::create(Context, Range, Factor);
+}
+
+WidthClause *TransformClauseReader::readWidthClause(SourceRange Range) {
+  Expr *Width = Record.readExpr();
+  return WidthClause::create(Context, Range, Width);
+}
+
+FactorClause *TransformClauseReader::readFactorClause(SourceRange Range) {
+  Expr *Factor = Record.readExpr();
+  return FactorClause::create(Context, Range, Factor);
+}
Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -771,9 +771,45 @@
 }
 }
 
+namespace {
+class TransformClauseProfiler
+    : public ConstTransformClauseVisitor<TransformClauseProfiler> {
+  StmtProfiler *Profiler;
+
+public:
+  TransformClauseProfiler(StmtProfiler *P) : Profiler(P) {}
+
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  void Visit##Name##Clause(const Name##Clause *);
+#include "clang/AST/TransformClauseKinds.def"
+};
+
+void TransformClauseProfiler::VisitFullClause(const FullClause *C) {
+  // The full clause has no arguments.
+}
+
+void TransformClauseProfiler::VisitPartialClause(const PartialClause *C) {
+  Profiler->VisitExpr(C->getFactor());
+}
+
+void TransformClauseProfiler::VisitWidthClause(const WidthClause *C) {
+  Profiler->VisitExpr(C->getWidth());
+}
+
+void TransformClauseProfiler::VisitFactorClause(const FactorClause *C) {
+  Profiler->VisitExpr(C->getFactor());
+}
+} // namespace
+
 void StmtProfiler::VisitTransformExecutableDirective(
     const TransformExecutableDirective *S) {
   VisitStmt(S);
+  TransformClauseProfiler P(this);
+  for (TransformClause *C : S->clauses()) {
+    if (!C)
+      continue;
+    P.Visit(C);
+  }
 }
 
 void
Index: clang/include/clang/Serialization/ASTWriter.h
===================================================================
--- clang/include/clang/Serialization/ASTWriter.h
+++ clang/include/clang/Serialization/ASTWriter.h
@@ -19,6 +19,7 @@
 #include "clang/AST/DeclarationName.h"
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/OpenMPClause.h"
+#include "clang/AST/StmtTransform.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
 #include "clang/AST/Type.h"
@@ -1013,6 +1014,24 @@
   void VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *C);
 };
 
+class TransformClauseWriter
+    : public ConstTransformClauseVisitor<TransformClauseWriter> {
+  ASTRecordWriter &Record;
+
+public:
+  TransformClauseWriter(ASTRecordWriter &Record) : Record(Record) {}
+
+  void writeClause(const TransformClause *C);
+
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  void Visit##Name##Clause(const Name##Clause *);
+#include "clang/AST/TransformClauseKinds.def"
+
+  void VisitTransformClause(const TransformClause *C) {
+    llvm_unreachable("Serialization of this clause not implemented");
+  }
+};
+
 } // namespace clang
 
 #endif // LLVM_CLANG_SERIALIZATION_ASTWRITER_H
Index: clang/include/clang/Serialization/ASTReader.h
===================================================================
--- clang/include/clang/Serialization/ASTReader.h
+++ clang/include/clang/Serialization/ASTReader.h
@@ -2719,6 +2719,51 @@
   void VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *C);
 };
 
+class TransformClauseReader {
+  ASTRecordReader &Record;
+  ASTContext &Context;
+
+public:
+  TransformClauseReader(ASTRecordReader &Record)
+      : Record(Record), Context(Record.getContext()) {}
+
+  TransformClause *readClause();
+
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  Name##Clause *read##Name##Clause(SourceRange);
+#include "clang/AST/TransformClauseKinds.def"
+};
+
+#if 0
+
+class TransformClauseReader : public TransformClauseVisitor<TransformClauseReader> {
+  ASTRecordReader &Record;
+  ASTContext &Context;
+
+public:
+  TransformClauseReader(ASTRecordReader &Record) : Record(Record), Context(Record.getContext()) {}
+
+  TransformClause* readClause() {
+      uint64_t Kind = Record.readInt();
+SourceRange Range =   Record.readSourceRange();
+
+switch (Kind) {
+#define TRANSFORM_CLAUSE(Keyword, Name) case TransformClause::
+#include "clang/AST/TransformClauseKinds.def"
+}
+  }
+
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  void Visit##Name##Clause(Name##Clause *);
+#include "clang/AST/TransformClauseKinds.def"
+
+    void VisitTransformClause( TransformClause * C) {
+      llvm_unreachable("Serialization of this clause not implemented");
+    }
+};
+
+#endif
+
 } // namespace clang
 
 #endif // LLVM_CLANG_SERIALIZATION_ASTREADER_H
Index: clang/include/clang/Serialization/ASTBitCodes.h
===================================================================
--- clang/include/clang/Serialization/ASTBitCodes.h
+++ clang/include/clang/Serialization/ASTBitCodes.h
@@ -1939,6 +1939,9 @@
       STMT_SEH_FINALLY,                 // SEHFinallyStmt
       STMT_SEH_TRY,                     // SEHTryStmt
 
+      // Code transformation directives.
+      STMT_TRANSFORM_EXECUTABLE_DIRECTIVE, // TransformExecutableDirective
+
       // OpenMP directives
       STMT_OMP_PARALLEL_DIRECTIVE,
       STMT_OMP_SIMD_DIRECTIVE,
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to