jyu2 updated this revision to Diff 334969.
jyu2 added a comment.

Thank you, Alexey, for the review!
These changes have been addressed as you suggested.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D99679/new/

https://reviews.llvm.org/D99679

Files:
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/AST/StmtProfile.cpp
  clang/lib/Basic/OpenMPKinds.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/OpenMP/dispatch_ast_print.cpp
  clang/test/OpenMP/dispatch_messages.cpp
  clang/tools/libclang/CIndex.cpp
  flang/lib/Semantics/check-omp-structure.cpp
  llvm/include/llvm/Frontend/OpenMP/OMP.td

Index: llvm/include/llvm/Frontend/OpenMP/OMP.td
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -276,6 +276,10 @@
 def OMPC_Destroy : Clause<"destroy"> {
   let clangClass = "OMPDestroyClause";
 }
+def OMPC_Novariants : Clause<"novariants"> {
+  let clangClass = "OMPNovariantsClause";
+  let flangClass = "ScalarLogicalExpr";
+}
 def OMPC_Detach : Clause<"detach"> {
   let clangClass = "OMPDetachClause";
 }
@@ -1660,7 +1664,8 @@
     VersionedClause<OMPC_Device>,
     VersionedClause<OMPC_IsDevicePtr>,
     VersionedClause<OMPC_NoWait>,
-    VersionedClause<OMPC_Depend>
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_Novariants>
   ];
 }
 def OMP_Unknown : Directive<"unknown"> {
Index: flang/lib/Semantics/check-omp-structure.cpp
===================================================================
--- flang/lib/Semantics/check-omp-structure.cpp
+++ flang/lib/Semantics/check-omp-structure.cpp
@@ -729,6 +729,7 @@
 CHECK_SIMPLE_CLAUSE(Write, OMPC_write)
 CHECK_SIMPLE_CLAUSE(Init, OMPC_init)
 CHECK_SIMPLE_CLAUSE(Use, OMPC_use)
+CHECK_SIMPLE_CLAUSE(Novariants, OMPC_novariants)
 
 CHECK_REQ_SCALAR_INT_CLAUSE(Allocator, OMPC_allocator)
 CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize)
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -2291,6 +2291,11 @@
     Visitor->AddStmt(C->getInteropVar());
 }
 
+void OMPClauseEnqueue::VisitOMPNovariantsClause(const OMPNovariantsClause *C) {
+  VisitOMPClauseWithPreInit(C);
+  Visitor->AddStmt(C->getCondition());
+}
+
 void OMPClauseEnqueue::VisitOMPUnifiedAddressClause(
     const OMPUnifiedAddressClause *) {}
 
Index: clang/test/OpenMP/dispatch_messages.cpp
===================================================================
--- clang/test/OpenMP/dispatch_messages.cpp
+++ clang/test/OpenMP/dispatch_messages.cpp
@@ -28,6 +28,24 @@
   // expected-error@+1 {{cannot contain more than one 'nowait' clause}}
   #pragma omp dispatch nowait device(dnum) nowait
   disp_call();
+
+  // expected-error@+1 {{expected '(' after 'novariants'}}
+  #pragma omp dispatch novariants
+  disp_call();
+
+  // expected-error@+3 {{expected expression}}
+  // expected-error@+2 {{expected ')'}}
+  // expected-note@+1 {{to match this '('}}
+  #pragma omp dispatch novariants (
+  disp_call();
+
+  // expected-error@+1 {{cannot contain more than one 'novariants' clause}}
+  #pragma omp dispatch novariants(dnum > 4) novariants(3)
+  disp_call();
+
+  // expected-error@+1 {{use of undeclared identifier 'x'}}
+  #pragma omp dispatch novariants(x)
+  disp_call();
 }
 
 void testit_two() {
Index: clang/test/OpenMP/dispatch_ast_print.cpp
===================================================================
--- clang/test/OpenMP/dispatch_ast_print.cpp
+++ clang/test/OpenMP/dispatch_ast_print.cpp
@@ -51,20 +51,22 @@
 void test_one()
 {
   int aaa, bbb, var;
-  //PRINT: #pragma omp dispatch depend(in : var) nowait
+  //PRINT: #pragma omp dispatch depend(in : var) nowait novariants(aaa > 5)
   //DUMP: OMPDispatchDirective
   //DUMP: OMPDependClause
   //DUMP: OMPNowaitClause
-  #pragma omp dispatch depend(in:var) nowait
+  //DUMP: OMPNovariantsClause
+  #pragma omp dispatch depend(in:var) nowait novariants(aaa > 5)
   foo(aaa, &bbb);
 
   int *dp = get_device_ptr();
   int dev = get_device();
-  //PRINT: #pragma omp dispatch device(dev) is_device_ptr(dp)
+  //PRINT: #pragma omp dispatch device(dev) is_device_ptr(dp) novariants(dev > 10)
   //DUMP: OMPDispatchDirective
   //DUMP: OMPDeviceClause
   //DUMP: OMPIs_device_ptrClause
-  #pragma omp dispatch device(dev) is_device_ptr(dp)
+  //DUMP: OMPNovariantsClause
+  #pragma omp dispatch device(dev) is_device_ptr(dp) novariants(dev > 10)
   foo(aaa, dp);
 
   //PRINT: #pragma omp dispatch
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6237,6 +6237,12 @@
   Record.AddSourceLocation(C->getVarLoc());
 }
 
+void OMPClauseWriter::VisitOMPNovariantsClause(OMPNovariantsClause *C) {
+  VisitOMPClauseWithPreInit(C);
+  Record.AddStmt(C->getCondition());
+  Record.AddSourceLocation(C->getLParenLoc());
+}
+
 void OMPClauseWriter::VisitOMPPrivateClause(OMPPrivateClause *C) {
   Record.push_back(C->varlist_size());
   Record.AddSourceLocation(C->getLParenLoc());
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -11977,6 +11977,9 @@
   case llvm::omp::OMPC_destroy:
     C = new (Context) OMPDestroyClause();
     break;
+  case llvm::omp::OMPC_novariants:
+    C = new (Context) OMPNovariantsClause();
+    break;
   case llvm::omp::OMPC_detach:
     C = new (Context) OMPDetachClause();
     break;
@@ -12162,6 +12165,12 @@
   C->setVarLoc(Record.readSourceLocation());
 }
 
+void OMPClauseReader::VisitOMPNovariantsClause(OMPNovariantsClause *C) {
+  VisitOMPClauseWithPreInit(C);
+  C->setCondition(Record.readSubExpr());
+  C->setLParenLoc(Record.readSourceLocation());
+}
+
 void OMPClauseReader::VisitOMPUnifiedAddressClause(OMPUnifiedAddressClause *) {}
 
 void OMPClauseReader::VisitOMPUnifiedSharedMemoryClause(
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -2208,6 +2208,18 @@
                                               VarLoc, EndLoc);
   }
 
+  /// Build a new OpenMP 'novariants' clause.
+  ///
+  /// By default, performs semantic analysis to build the new OpenMP clause.
+  /// Subclasses may override this routine to provide different behavior.
+  OMPClause *RebuildOMPNovariantsClause(Expr *Condition,
+                                        SourceLocation StartLoc,
+                                        SourceLocation LParenLoc,
+                                        SourceLocation EndLoc) {
+    return getSema().ActOnOpenMPNovariantsClause(Condition, StartLoc, LParenLoc,
+                                                 EndLoc);
+  }
+
   /// Rebuild the operand to an Objective-C \@synchronized statement.
   ///
   /// By default, performs semantic analysis to build the new statement.
@@ -9377,6 +9389,16 @@
                                               C->getEndLoc());
 }
 
+template <typename Derived>
+OMPClause *
+TreeTransform<Derived>::TransformOMPNovariantsClause(OMPNovariantsClause *C) {
+  ExprResult Cond = getDerived().TransformExpr(C->getCondition());
+  if (Cond.isInvalid())
+    return nullptr;
+  return getDerived().RebuildOMPNovariantsClause(
+      Cond.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
+}
+
 template <typename Derived>
 OMPClause *TreeTransform<Derived>::TransformOMPUnifiedAddressClause(
     OMPUnifiedAddressClause *C) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -6173,6 +6173,7 @@
       case OMPC_num_tasks:
       case OMPC_final:
       case OMPC_priority:
+      case OMPC_novariants:
         // Do not analyze if no parent parallel directive.
         if (isOpenMPParallelDirective(Kind))
           break;
@@ -12785,6 +12786,9 @@
   case OMPC_detach:
     Res = ActOnOpenMPDetachClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_novariants:
+    Res = ActOnOpenMPNovariantsClause(Expr, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_device:
   case OMPC_if:
   case OMPC_default:
@@ -13557,6 +13561,15 @@
       llvm_unreachable("Unknown OpenMP directive");
     }
     break;
+  case OMPC_novariants:
+    switch (DKind) {
+    case OMPD_dispatch:
+      CaptureRegion = OMPD_task;
+      break;
+    default:
+      llvm_unreachable("Unknown OpenMP directive");
+    }
+    break;
   case OMPC_firstprivate:
   case OMPC_lastprivate:
   case OMPC_reduction:
@@ -14061,6 +14074,7 @@
   case OMPC_match:
   case OMPC_nontemporal:
   case OMPC_destroy:
+  case OMPC_novariants:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -14317,6 +14331,7 @@
   case OMPC_nontemporal:
   case OMPC_order:
   case OMPC_destroy:
+  case OMPC_novariants:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -14558,6 +14573,7 @@
   case OMPC_match:
   case OMPC_nontemporal:
   case OMPC_order:
+  case OMPC_novariants:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -14848,6 +14864,38 @@
       OMPDestroyClause(InteropVar, StartLoc, LParenLoc, VarLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPNovariantsClause(Expr *Condition,
+                                             SourceLocation StartLoc,
+                                             SourceLocation LParenLoc,
+                                             SourceLocation EndLoc) {
+  Expr *ValExpr = Condition;
+  Stmt *HelperValStmt = nullptr;
+  OpenMPDirectiveKind CaptureRegion = OMPD_unknown;
+  if (!Condition->isValueDependent() && !Condition->isTypeDependent() &&
+      !Condition->isInstantiationDependent() &&
+      !Condition->containsUnexpandedParameterPack()) {
+    ExprResult Val = CheckBooleanCondition(StartLoc, Condition);
+    if (Val.isInvalid())
+      return nullptr;
+
+    ValExpr = MakeFullExpr(Val.get()).get();
+
+    OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
+    CaptureRegion = getOpenMPCaptureRegionForClause(DKind, OMPC_novariants,
+                                                    LangOpts.OpenMP);
+    if (CaptureRegion != OMPD_unknown && !CurContext->isDependentContext()) {
+      ValExpr = MakeFullExpr(ValExpr).get();
+      llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+      // Capture the condition; buildPreInits builds the pre-init from Captures.
+      ValExpr = tryBuildCapture(*this, ValExpr, Captures).get();
+      HelperValStmt = buildPreInits(Context, Captures);
+    }
+  }
+
+  return new (Context) OMPNovariantsClause(
+      ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPVarListClause(
     OpenMPClauseKind Kind, ArrayRef<Expr *> VarList, Expr *DepModOrTailExpr,
     const OMPVarListLocTy &Locs, SourceLocation ColonLoc,
@@ -15018,6 +15066,7 @@
   case OMPC_match:
   case OMPC_order:
   case OMPC_destroy:
+  case OMPC_novariants:
   case OMPC_detach:
   case OMPC_uses_allocators:
   default:
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -2776,6 +2776,7 @@
   case OMPC_allocator:
   case OMPC_depobj:
   case OMPC_detach:
+  case OMPC_novariants:
     // OpenMP [2.5, Restrictions]
     //  At most one num_threads clause can appear on the directive.
     // OpenMP [2.8.1, simd construct, Restrictions]
@@ -2798,6 +2799,8 @@
     // At most one allocator clause can appear on the directive.
     // OpenMP 5.0, 2.10.1 task Construct, Restrictions.
     // At most one detach clause can appear on the directive.
+    // OpenMP 5.1, 2.3.6 dispatch Construct, Restrictions.
+    // At most one novariants clause can appear on a dispatch directive.
     if (!FirstClause) {
       Diag(Tok, diag::err_omp_more_one_clause)
           << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
Index: clang/lib/Basic/OpenMPKinds.cpp
===================================================================
--- clang/lib/Basic/OpenMPKinds.cpp
+++ clang/lib/Basic/OpenMPKinds.cpp
@@ -176,6 +176,7 @@
   case OMPC_match:
   case OMPC_nontemporal:
   case OMPC_destroy:
+  case OMPC_novariants:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -418,6 +419,7 @@
   case OMPC_nontemporal:
   case OMPC_destroy:
+  case OMPC_novariants:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -483,6 +483,12 @@
     Profiler->VisitStmt(Evt);
 }
 
+void OMPClauseProfiler::VisitOMPNovariantsClause(const OMPNovariantsClause *C) {
+  VistOMPClauseWithPreInit(C);
+  if (C->getCondition())
+    Profiler->VisitStmt(C->getCondition());
+}
+
 void OMPClauseProfiler::VisitOMPDefaultClause(const OMPDefaultClause *C) { }
 
 void OMPClauseProfiler::VisitOMPProcBindClause(const OMPProcBindClause *C) { }
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -96,6 +96,8 @@
     return static_cast<const OMPFinalClause *>(C);
   case OMPC_priority:
     return static_cast<const OMPPriorityClause *>(C);
+  case OMPC_novariants:
+    return static_cast<const OMPNovariantsClause *>(C);
   case OMPC_default:
   case OMPC_proc_bind:
   case OMPC_safelen:
@@ -244,6 +246,7 @@
   case OMPC_nontemporal:
   case OMPC_order:
   case OMPC_destroy:
+  case OMPC_novariants:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -300,6 +303,12 @@
   return child_range(&Priority, &Priority + 1);
 }
 
+OMPClause::child_range OMPNovariantsClause::used_children() {
+  if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt()))
+    return child_range(C, C + 1);
+  return child_range(&Condition, &Condition + 1);
+}
+
 OMPOrderedClause *OMPOrderedClause::Create(const ASTContext &C, Expr *Num,
                                            unsigned NumLoops,
                                            SourceLocation StartLoc,
@@ -1816,6 +1825,15 @@
   }
 }
 
+void OMPClausePrinter::VisitOMPNovariantsClause(OMPNovariantsClause *Node) {
+  OS << "novariants";
+  if (Expr *E = Node->getCondition()) {
+    OS << "(";
+    E->printPretty(OS, nullptr, Policy, 0);
+    OS << ")";
+  }
+}
+
 template<typename T>
 void OMPClausePrinter::VisitOMPClauseList(T *Node, char StartSym) {
   for (typename T::varlist_iterator I = Node->varlist_begin(),
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11017,7 +11017,12 @@
                                       SourceLocation LParenLoc,
                                       SourceLocation VarLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed 'novariants' clause.
+  OMPClause *ActOnOpenMPNovariantsClause(Expr *Condition,
+                                         SourceLocation StartLoc,
+                                         SourceLocation LParenLoc,
+                                         SourceLocation EndLoc);
 
   /// Called on well-formed 'threads' clause.
   OMPClause *ActOnOpenMPThreadsClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3218,6 +3218,14 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPNovariantsClause(
+    OMPNovariantsClause *C) {
+  TRY_TO(VisitOMPClauseWithPreInit(C));
+  TRY_TO(TraverseStmt(C->getCondition()));
+  return true;
+}
+
 template <typename Derived>
 template <typename T>
 bool RecursiveASTVisitor<Derived>::VisitOMPClauseList(T *Node) {
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -7649,6 +7649,77 @@
   }
 };
 
+/// This represents 'novariants' clause in the '#pragma omp ...' directive.
+///
+/// \code
+/// #pragma omp dispatch novariants(a > 5)
+/// \endcode
+/// In this example directive '#pragma omp dispatch' has simple 'novariants'
+/// clause with condition 'a > 5'.
+class OMPNovariantsClause final : public OMPClause,
+                                  public OMPClauseWithPreInit {
+  friend class OMPClauseReader;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+  /// Condition of the 'novariants' clause.
+  Stmt *Condition = nullptr;
+
+  /// Set condition.
+  void setCondition(Expr *Cond) { Condition = Cond; }
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+public:
+  /// Build 'novariants' clause with condition \a Cond.
+  ///
+  /// \param Cond Condition of the clause.
+  /// \param HelperCond Helper condition for the construct.
+  /// \param CaptureRegion Innermost OpenMP region where expressions in this
+  /// clause must be captured.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPNovariantsClause(Expr *Cond, Stmt *HelperCond,
+                      OpenMPDirectiveKind CaptureRegion,
+                      SourceLocation StartLoc, SourceLocation LParenLoc,
+                      SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_novariants, StartLoc, EndLoc),
+        OMPClauseWithPreInit(this), LParenLoc(LParenLoc), Condition(Cond) {
+    setPreInitStmt(HelperCond, CaptureRegion);
+  }
+
+  /// Build an empty clause.
+  OMPNovariantsClause()
+      : OMPClause(llvm::omp::OMPC_novariants, SourceLocation(),
+                  SourceLocation()),
+        OMPClauseWithPreInit(this) {}
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns condition.
+  Expr *getCondition() const { return cast_or_null<Expr>(Condition); }
+
+  child_range children() { return child_range(&Condition, &Condition + 1); }
+
+  const_child_range children() const {
+    return const_child_range(&Condition, &Condition + 1);
+  }
+
+  child_range used_children();
+  const_child_range used_children() const {
+    auto Children = const_cast<OMPNovariantsClause *>(this)->used_children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_novariants;
+  }
+};
+
 /// This represents 'detach' clause in the '#pragma omp task' directive.
 ///
 /// \code
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to