koops updated this revision to Diff 557808.
koops added a comment.

1. In class OMPFailClause, the parameter of the fail clause was stored twice:
once as FailParameterKind and once as a memory-order clause
(FailMemoryOrderClause). The two could get out of sync, and the duplication
was confusing to readers of the code. Hence only FailMemoryOrderClause is
stored now; the FailParameterKind (of type OpenMPClauseKind) is derived from
FailMemoryOrderClause when needed.
2. Visit(const OMPClause *C) now checks
`if (const auto *OMPC = dyn_cast<OMPFailClause>(C))`. This is done so that the
parameter of the fail clause, which is itself a memory-order clause, is also
visited (e.g. so that it appears in -ast-dump output).


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D123235/new/

https://reviews.llvm.org/D123235

Files:
  clang/include/clang/AST/ASTNodeTraverser.h
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Basic/OpenMPKinds.def
  clang/include/clang/Parse/Parser.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/AST/StmtProfile.cpp
  clang/lib/Basic/OpenMPKinds.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/OpenMP/atomic_ast_print.cpp
  clang/test/OpenMP/atomic_messages.cpp
  clang/tools/libclang/CIndex.cpp
  flang/lib/Semantics/check-omp-structure.cpp
  llvm/include/llvm/Frontend/OpenMP/OMP.td

Index: llvm/include/llvm/Frontend/OpenMP/OMP.td
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -209,6 +209,7 @@
 def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; }
 def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; }
 def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; }
+def OMPC_Fail : Clause<"fail"> { let clangClass = "OMPFailClause"; }
 def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; }
 def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }
@@ -637,7 +638,8 @@
     VersionedClause<OMPC_Acquire, 50>,
     VersionedClause<OMPC_Release, 50>,
     VersionedClause<OMPC_Relaxed, 50>,
-    VersionedClause<OMPC_Hint, 50>
+    VersionedClause<OMPC_Hint, 50>,
+    VersionedClause<OMPC_Fail, 51>
   ];
 }
 def OMP_Target : Directive<"target"> {
Index: flang/lib/Semantics/check-omp-structure.cpp
===================================================================
--- flang/lib/Semantics/check-omp-structure.cpp
+++ flang/lib/Semantics/check-omp-structure.cpp
@@ -2164,6 +2164,7 @@
 CHECK_SIMPLE_CLAUSE(Doacross, OMPC_doacross)
 CHECK_SIMPLE_CLAUSE(OmpxAttribute, OMPC_ompx_attribute)
 CHECK_SIMPLE_CLAUSE(OmpxBare, OMPC_ompx_bare)
+CHECK_SIMPLE_CLAUSE(Fail, OMPC_fail)
 
 CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize)
 CHECK_REQ_SCALAR_INT_CLAUSE(NumTasks, OMPC_num_tasks)
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -2402,6 +2402,8 @@
 
 void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseEnqueue::VisitOMPFailClause(const OMPFailClause *) {}
+
 void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/test/OpenMP/atomic_messages.cpp
===================================================================
--- clang/test/OpenMP/atomic_messages.cpp
+++ clang/test/OpenMP/atomic_messages.cpp
@@ -958,6 +958,25 @@
 // expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'capture' clause}}
 #pragma omp atomic compare compare capture capture
   { v = a; if (a > b) a = b; }
+// expected-error@+1 {{expected 'compare' clause with the 'fail' modifier}}
+#pragma omp atomic fail(seq_cst)
+  if(v == a) { v = a; }
+// expected-error@+2 {{expected '(' after 'fail'}}
+// expected-error@+1 {{expected a memory order clause}}
+#pragma omp atomic compare fail
+  if(v < a) { v = a; }
+// expected-error@+1 {{expected a memory order clause}}
+#pragma omp atomic compare fail(capture)
+  if(v < a) { v = a; }
+ // expected-error@+2 {{expected ')' after 'atomic compare fail'}}
+ // expected-warning@+1 {{extra tokens at the end of '#pragma omp atomic' are ignored}}
+#pragma omp atomic compare fail(seq_cst | acquire)
+  if(v < a) { v = a; }
+// expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'fail' clause}}
+#pragma omp atomic compare fail(relaxed) fail(seq_cst)
+  if(v < a) { v = a; }
+
+
 #endif
   // expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
   return mixed<int>();
Index: clang/test/OpenMP/atomic_ast_print.cpp
===================================================================
--- clang/test/OpenMP/atomic_ast_print.cpp
+++ clang/test/OpenMP/atomic_ast_print.cpp
@@ -226,6 +226,16 @@
   { v = a; if (a < b) { a = b; } }
 #pragma omp atomic compare capture hint(6)
   { v = a == b; if (v) a = c; }
+#pragma omp atomic compare fail(acq_rel)
+  { if (a < c) { a = c; } }
+#pragma omp atomic compare fail(acquire)
+  { if (a < c) { a = c; } }
+#pragma omp atomic compare fail(release)
+  { if (a < c) { a = c; } }
+#pragma omp atomic compare fail(relaxed)
+  { if (a < c) { a = c; } }
+#pragma omp atomic compare fail(seq_cst)
+  { if (a < c) { a = c; } }
 #endif
   return T();
 }
@@ -1099,6 +1109,16 @@
   { v = a; if (a < b) { a = b; } }
 #pragma omp atomic compare capture hint(6)
   { v = a == b; if (v) a = c; }
+#pragma omp atomic compare fail(acq_rel)
+  if(a < b) { a = b; }
+#pragma omp atomic compare fail(acquire)
+  if(a < b) { a = b; }
+#pragma omp atomic compare fail(release)
+  if(a < b) { a = b; }
+#pragma omp atomic compare fail(relaxed)
+  if(a < b) { a = b; }
+#pragma omp atomic compare fail(seq_cst)
+  if(a < b) { a = b; }
 #endif
   // CHECK-NEXT: #pragma omp atomic
   // CHECK-NEXT: a++;
@@ -1429,6 +1449,26 @@
   // CHECK-51-NEXT: if (v)
   // CHECK-51-NEXT: a = c;
   // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(acquire)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(acquire)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(relaxed)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(relaxed)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(seq_cst)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
   // expect-note@+1 {{in instantiation of function template specialization 'foo<int>' requested here}}
   return foo(a);
 }
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6606,6 +6606,13 @@
 
 void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {}
 
+// Serialize a fail clause: '(' loc, argument loc, then the memory-order kind.
+void OMPClauseWriter::VisitOMPFailClause(OMPFailClause *C) {
+  Record.AddSourceLocation(C->getLParenLoc());
+  Record.AddSourceLocation(C->getArgumentLoc());
+  Record.writeEnum(C->getMemoryOrderClauseKind());
+}
+
 void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -10275,6 +10275,9 @@
   case llvm::omp::OMPC_compare:
     C = new (Context) OMPCompareClause();
     break;
+  case llvm::omp::OMPC_fail:
+    C = OMPFailClause::CreateEmpty(Context);
+    break;
   case llvm::omp::OMPC_seq_cst:
     C = new (Context) OMPSeqCstClause();
     break;
@@ -10668,6 +10671,16 @@
 
 void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {}
 
+// Deserialize a fail clause in the order OMPClauseWriter wrote it. ArgumentLoc
+// is set before the kind because the stored clause is built using it.
+void OMPClauseReader::VisitOMPFailClause(OMPFailClause *C) {
+  C->setLParenLoc(Record.readSourceLocation());
+  SourceLocation ArgumentLoc = Record.readSourceLocation();
+  C->setArgumentLoc(ArgumentLoc);
+  OpenMPClauseKind CKind = Record.readEnum<OpenMPClauseKind>();
+  C->setMemoryOrderClauseKind(CKind);
+}
+
 void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -9876,6 +9876,12 @@
   return C;
 }
 
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPFailClause(OMPFailClause *C) {
+  // No need to rebuild this clause, no template-dependent parameters.
+  return C;
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPSeqCstClause(OMPSeqCstClause *C) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -12619,6 +12619,33 @@
 }
 } // namespace
 
+void Sema::checkFailClauseParameters(ArrayRef<OMPClause *> Clauses) {
+  int NoOfFails = 0;
+  // Diagnose repeated 'fail' clauses and missing/invalid memory-order args.
+  for (const OMPClause *C : Clauses) {
+    const auto *FC = dyn_cast<OMPFailClause>(C);
+    if (!FC)
+      continue;
+    const OMPClause *MemOrderC = FC->getMemoryOrderClause();
+    NoOfFails++;
+    if (NoOfFails > 1) {
+      Diag(FC->getBeginLoc(), diag::err_omp_atomic_fail_extra_clauses);
+    }
+    // Only acq_rel, acquire, relaxed, release or seq_cst are valid arguments.
+    if (MemOrderC) {
+      OpenMPClauseKind ClauseKind = MemOrderC->getClauseKind();
+      if (!((ClauseKind == OMPC_acq_rel) || (ClauseKind == OMPC_acquire) ||
+            (ClauseKind == OMPC_relaxed) || (ClauseKind == OMPC_release) ||
+            (ClauseKind == OMPC_seq_cst))) {
+        Diag(MemOrderC->getBeginLoc(),
+             diag::err_omp_atomic_fail_wrong_or_no_clauses);
+      }
+    } else if (FC->getMemoryOrderClauseKind() == OMPC_unknown) {
+      Diag(FC->getBeginLoc(), diag::err_omp_atomic_fail_wrong_or_no_clauses);
+    }
+  }
+}
+
 StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
                                             Stmt *AStmt,
                                             SourceLocation StartLoc,
@@ -12637,6 +12664,8 @@
   SourceLocation AtomicKindLoc;
   OpenMPClauseKind MemOrderKind = OMPC_unknown;
   SourceLocation MemOrderLoc;
+  llvm::omp::Clause SubClause = OMPC_unknown;
+  SourceLocation SubClauseLoc;
   bool MutexClauseEncountered = false;
   llvm::SmallSet<OpenMPClauseKind, 2> EncounteredAtomicKinds;
   for (const OMPClause *C : Clauses) {
@@ -12665,6 +12694,21 @@
       }
       break;
     }
+    case OMPC_fail: {
+      // 'compare' may appear before or after 'fail' in the clause list, so
+      // search every clause rather than the atomic kind recorded so far.
+      bool HasCompare = llvm::any_of(Clauses, [](const OMPClause *Cl) {
+        return Cl->getClauseKind() == OMPC_compare;
+      });
+      if (!HasCompare) {
+        Diag(C->getBeginLoc(), diag::err_omp_atomic_fail_no_compare)
+            << SourceRange(C->getBeginLoc(), C->getEndLoc());
+        return StmtError();
+      }
+      SubClause = OMPC_fail;
+      SubClauseLoc = C->getBeginLoc();
+      break;
+    }
     case OMPC_seq_cst:
     case OMPC_acq_rel:
     case OMPC_acquire:
@@ -13153,6 +13197,8 @@
       CE = Checker.getCond();
       // We reuse IsXLHSInRHSPart to tell if it is in the form 'x ordop expr'.
       IsXLHSInRHSPart = Checker.isXBinopExpr();
+      if (SubClause == OMPC_fail)
+        checkFailClauseParameters(Clauses);
     }
   }
 
@@ -16866,6 +16912,11 @@
         static_cast<OpenMPAtomicDefaultMemOrderClauseKind>(Argument),
         ArgumentLoc, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_fail:
+    Res = ActOnOpenMPFailClause(
+        static_cast<OpenMPClauseKind>(Argument),
+        ArgumentLoc, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_update:
     Res = ActOnOpenMPUpdateClause(static_cast<OpenMPDependClauseKind>(Argument),
                                   ArgumentLoc, StartLoc, LParenLoc, EndLoc);
@@ -17506,6 +17557,9 @@
   case OMPC_compare:
     Res = ActOnOpenMPCompareClause(StartLoc, EndLoc);
     break;
+  case OMPC_fail:
+    Res = ActOnOpenMPFailClause(StartLoc, EndLoc);
+    break;
   case OMPC_seq_cst:
     Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc);
     break;
@@ -17666,6 +17720,19 @@
   return new (Context) OMPCompareClause(StartLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPFailClause(SourceLocation StartLoc,
+                                       SourceLocation EndLoc) {
+  return OMPFailClause::Create(Context, StartLoc, EndLoc);
+}
+
+// Build a 'fail' clause carrying the memory-order kind Parameter.
+OMPClause *Sema::ActOnOpenMPFailClause(
+    OpenMPClauseKind Parameter, SourceLocation KindLoc, SourceLocation StartLoc,
+    SourceLocation LParenLoc, SourceLocation EndLoc) {
+  return OMPFailClause::Create(Context, Parameter, KindLoc, StartLoc,
+                               LParenLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                          SourceLocation EndLoc) {
   return new (Context) OMPSeqCstClause(StartLoc, EndLoc);
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -3305,6 +3305,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_fail:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -3801,6 +3802,44 @@
                                          Val->Loc, Val->RLoc);
 }
 
+OMPClause *Parser::ParseOpenMPFailClause(OMPClause *Clause) {
+  // Parse the parenthesized memory-order argument that follows 'fail'.
+  auto *FailClause = cast<OMPFailClause>(Clause);
+  SourceLocation LParenLoc;
+  if (Tok.is(tok::l_paren)) {
+    LParenLoc = Tok.getLocation();
+    ConsumeAnyToken();
+  } else {
+    Diag(diag::err_expected_lparen_after) << getOpenMPClauseName(OMPC_fail);
+    return Clause;
+  }
+
+  OpenMPClauseKind CKind = Tok.isAnnotation()
+                               ? OMPC_unknown
+                               : getOpenMPClauseKind(PP.getSpelling(Tok));
+  if (CKind == OMPC_unknown) {
+    Diag(diag::err_omp_expected_clause) << "atomic compare fail";
+    return Clause;
+  }
+  // Capture the argument location before ParseOpenMPClause consumes it.
+  SourceLocation MemOrderLoc = Tok.getLocation();
+  OMPClause *MemoryOrderClause = ParseOpenMPClause(CKind, false);
+  if (!MemoryOrderClause)
+    MemOrderLoc = SourceLocation();
+
+  FailClause->initFailClause(LParenLoc, MemoryOrderClause, MemOrderLoc);
+
+  if (Tok.is(tok::r_paren)) {
+    ConsumeAnyToken();
+  } else {
+    const IdentifierInfo *Arg = Tok.getIdentifierInfo();
+    Diag(Tok, diag::err_expected_rparen_after)
+        << (Arg ? Arg->getName() : "atomic compare fail");
+  }
+
+  return Clause;
+}
+
 /// Parsing of OpenMP clauses like 'ordered'.
 ///
 ///    ordered-clause:
@@ -3833,7 +3872,10 @@
 
   if (ParseOnly)
     return nullptr;
-  return Actions.ActOnOpenMPClause(Kind, Loc, Tok.getLocation());
+  OMPClause *Clause = Actions.ActOnOpenMPClause(Kind, Loc, Tok.getLocation());
+  if (Kind == llvm::omp::Clause::OMPC_fail)
+    Clause = ParseOpenMPFailClause(Clause);
+  return Clause;
 }
 
 /// Parsing of OpenMP clauses with single expressions and some additional
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -6516,6 +6516,10 @@
                              IsPostfixUpdate, IsFailOnly, Loc);
     break;
   }
+  case OMPC_fail: {
+    // TODO: codegen for the 'fail' memory order is not implemented yet.
+    break;
+  }
   default:
     llvm_unreachable("Clause is not allowed in 'omp atomic'.");
   }
Index: clang/lib/Basic/OpenMPKinds.cpp
===================================================================
--- clang/lib/Basic/OpenMPKinds.cpp
+++ clang/lib/Basic/OpenMPKinds.cpp
@@ -104,6 +104,11 @@
   .Case(#Name, OMPC_ATOMIC_DEFAULT_MEM_ORDER_##Name)
 #include "clang/Basic/OpenMPKinds.def"
         .Default(OMPC_ATOMIC_DEFAULT_MEM_ORDER_unknown);
+  case OMPC_fail:
+    return static_cast<unsigned int>(llvm::StringSwitch<llvm::omp::Clause>(Str)
+#define OPENMP_ATOMIC_FAIL_MODIFIER(Name) .Case(#Name, OMPC_##Name)
+#include "clang/Basic/OpenMPKinds.def"
+                                         .Default(OMPC_unknown));
   case OMPC_device_type:
     return llvm::StringSwitch<OpenMPDeviceType>(Str)
 #define OPENMP_DEVICE_TYPE_KIND(Name) .Case(#Name, OMPC_DEVICE_TYPE_##Name)
@@ -434,6 +439,22 @@
 #include "clang/Basic/OpenMPKinds.def"
     }
     llvm_unreachable("Invalid OpenMP 'depend' clause type");
+  case OMPC_fail: {
+    OpenMPClauseKind CK = static_cast<OpenMPClauseKind>(Type);
+    switch (CK) {
+    case OMPC_unknown:
+      return "unknown";
+#define OPENMP_ATOMIC_FAIL_MODIFIER(Name)                                      \
+  case OMPC_##Name:                                                            \
+    return #Name;
+#include "clang/Basic/OpenMPKinds.def"
+    default:
+      // Other clause kinds are not valid 'fail' arguments; the explicit
+      // default also keeps -Wswitch quiet for the remaining enumerators.
+      break;
+    }
+    llvm_unreachable("Invalid OpenMP 'fail' clause modifier");
+  }
   case OMPC_device:
     switch (Type) {
     case OMPC_DEVICE_unknown:
Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -582,6 +582,8 @@
 
 void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseProfiler::VisitOMPFailClause(const OMPFailClause *) {}
+
 void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -130,6 +130,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_fail:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -227,6 +228,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_fail:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -422,6 +424,25 @@
   return Clause;
 }
 
+OMPFailClause *OMPFailClause::Create(const ASTContext &C,
+                                     SourceLocation StartLoc,
+                                     SourceLocation EndLoc) {
+  return new (C) OMPFailClause(StartLoc, EndLoc);
+}
+
+OMPFailClause *OMPFailClause::CreateEmpty(const ASTContext &C) {
+  return new (C) OMPFailClause();
+}
+
+// Build a 'fail' clause whose parameter is the memory-order kind FailParameter.
+OMPFailClause *
+OMPFailClause::Create(const ASTContext &C, OpenMPClauseKind FailParameter,
+                      SourceLocation ArgumentLoc, SourceLocation StartLoc,
+                      SourceLocation LParenLoc, SourceLocation EndLoc) {
+  return new (C)
+      OMPFailClause(FailParameter, ArgumentLoc, StartLoc, LParenLoc, EndLoc);
+}
+
 void OMPPrivateClause::setPrivateCopies(ArrayRef<Expr *> VL) {
   assert(VL.size() == varlist_size() &&
          "Number of private copies is not the same as the preallocated buffer");
@@ -1923,6 +1944,17 @@
   OS << "compare";
 }
 
+void OMPClausePrinter::VisitOMPFailClause(OMPFailClause *Node) {
+  OS << "fail";
+  // Print the parenthesized memory-order argument only when one is stored;
+  // a clause without it prints as plain 'fail'.
+  if (Node->getMemoryOrderClause()) {
+    unsigned MOKind = static_cast<unsigned>(Node->getMemoryOrderClauseKind());
+    OS << "(" << getOpenMPSimpleClauseTypeName(Node->getClauseKind(), MOKind);
+    OS << ")";
+  }
+}
+
 void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   OS << "seq_cst";
 }
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11320,6 +11320,9 @@
                         OpenMPDirectiveKind &Kind,
                         OpenMPDirectiveKind &PrevMappedDirective);
 
+  /// Diagnose repeated 'fail' clauses and missing or invalid fail arguments.
+  void checkFailClauseParameters(ArrayRef<OMPClause *> Clauses);
+
 public:
   /// The declarator \p D defines a function in the scope \p S which is nested
   /// in an `omp begin/end declare variant` scope. In this method we create a
@@ -12191,6 +12194,13 @@
   /// Called on well-formed 'compare' clause.
   OMPClause *ActOnOpenMPCompareClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed 'fail' clause (with or without a memory order).
+  OMPClause *ActOnOpenMPFailClause(SourceLocation StartLoc,
+                                   SourceLocation EndLoc);
+  OMPClause *ActOnOpenMPFailClause(
+      OpenMPClauseKind Kind, SourceLocation KindLoc,
+      SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc);
+
   /// Called on well-formed 'seq_cst' clause.
   OMPClause *ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                      SourceLocation EndLoc);
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -431,6 +431,8 @@
   /// a statement expression and builds a suitable expression statement.
   StmtResult handleExprStmt(ExprResult E, ParsedStmtContext StmtCtx);
 
+  OMPClause *ParseOpenMPFailClause(OMPClause *Clause);
+
 public:
   Parser(Preprocessor &PP, Sema &Actions, bool SkipFunctionBodies);
   ~Parser() override;
Index: clang/include/clang/Basic/OpenMPKinds.def
===================================================================
--- clang/include/clang/Basic/OpenMPKinds.def
+++ clang/include/clang/Basic/OpenMPKinds.def
@@ -41,6 +41,9 @@
 #ifndef OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND
 #define OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(Name)
 #endif
+#ifndef OPENMP_ATOMIC_FAIL_MODIFIER
+#define OPENMP_ATOMIC_FAIL_MODIFIER(Name)
+#endif
 #ifndef OPENMP_AT_KIND
 #define OPENMP_AT_KIND(Name)
 #endif
@@ -137,6 +140,13 @@
 OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(acq_rel)
 OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(relaxed)
 
+// Modifiers for atomic 'fail' clause.
+OPENMP_ATOMIC_FAIL_MODIFIER(seq_cst)
+OPENMP_ATOMIC_FAIL_MODIFIER(acquire)
+OPENMP_ATOMIC_FAIL_MODIFIER(acq_rel)
+OPENMP_ATOMIC_FAIL_MODIFIER(relaxed)
+OPENMP_ATOMIC_FAIL_MODIFIER(release)
+
 // Modifiers for 'at' clause.
 OPENMP_AT_KIND(compilation)
 OPENMP_AT_KIND(execution)
@@ -225,6 +235,7 @@
 #undef OPENMP_SCHEDULE_MODIFIER
 #undef OPENMP_SCHEDULE_KIND
 #undef OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND
+#undef OPENMP_ATOMIC_FAIL_MODIFIER
 #undef OPENMP_AT_KIND
 #undef OPENMP_SEVERITY_KIND
 #undef OPENMP_MAP_KIND
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10946,6 +10946,10 @@
   "expect binary operator in conditional expression|expect '<', '>' or '==' as order operator|expect comparison in a form of 'x == e', 'e == x', 'x ordop expr', or 'expr ordop x'|"
   "expect lvalue for result value|expect scalar value|expect integer value|unexpected 'else' statement|expect '==' operator|expect an assignment statement 'v = x'|"
   "expect a 'if' statement|expect no more than two statements|expect a compound statement|expect 'else' statement|expect a form 'r = x == e; if (r) ...'}0">;
+def err_omp_atomic_fail_wrong_or_no_clauses : Error<"expected a memory order clause">;
+def err_omp_atomic_fail_extra_mem_order_clauses : Error<"directive '#pragma omp atomic compare fail' cannot contain more than one memory order clause">;
+def err_omp_atomic_fail_extra_clauses : Error<"directive '#pragma omp atomic compare' cannot contain more than one 'fail' clause">;
+def err_omp_atomic_fail_no_compare : Error<"expected 'compare' clause with the 'fail' modifier">;
 def err_omp_atomic_several_clauses : Error<
   "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
 def err_omp_several_mem_order_clauses : Error<
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3398,6 +3398,11 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPFailClause(OMPFailClause *) {
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   return true;
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -2513,6 +2513,139 @@
   }
 };
 
+/// This represents 'fail' clause in the '#pragma omp atomic'
+/// directive.
+///
+/// \code
+/// #pragma omp atomic compare fail(seq_cst)
+/// \endcode
+/// Here '#pragma omp atomic compare' has a 'fail' clause with 'seq_cst'.
+class OMPFailClause final : public OMPClause {
+  /// Memory-order clause given as the 'fail' argument, or null if none.
+  OMPClause *FailMemoryOrderClause = nullptr;
+  SourceLocation ArgumentLoc;
+  SourceLocation LParenLoc;
+
+  friend class OMPClauseReader;
+
+  /// Sets the location of '(' in fail clause.
+  void setLParenLoc(SourceLocation Loc) {
+    LParenLoc = Loc;
+  }
+
+  /// Sets the location of memoryOrder clause argument in fail clause.
+  void setArgumentLoc(SourceLocation Loc) {
+    ArgumentLoc = Loc;
+  }
+
+  /// Stores a clause for \p MemOrderKind: acq_rel->acquire, release->relaxed.
+  void setMemoryOrderClauseKind(OpenMPClauseKind MemOrderKind) {
+    // NOTE(review): uses plain 'new', not the ASTContext; nothing here frees it.
+    switch (MemOrderKind) {
+    case llvm::omp::OMPC_acq_rel:
+    case llvm::omp::OMPC_acquire:
+      FailMemoryOrderClause = new OMPAcquireClause(ArgumentLoc, getEndLoc());
+      break;
+    case llvm::omp::OMPC_relaxed:
+    case llvm::omp::OMPC_release:
+      FailMemoryOrderClause = new OMPRelaxedClause(ArgumentLoc, getEndLoc());
+      break;
+    case llvm::omp::OMPC_seq_cst:
+      FailMemoryOrderClause = new OMPSeqCstClause(ArgumentLoc, getEndLoc());
+      break;
+    default:
+      FailMemoryOrderClause = nullptr;
+      break;
+    }
+  }
+
+  /// Directly replaces the stored memory-order clause.
+  void setMemoryOrderClause(OMPClause *MemoryOrderClause) {
+    this->FailMemoryOrderClause = MemoryOrderClause;
+  }
+
+public:
+  /// Build 'fail' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPFailClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_fail, StartLoc, EndLoc) {}
+
+  OMPFailClause(OpenMPClauseKind FailParameter, SourceLocation ArgumentLoc,
+                SourceLocation StartLoc, SourceLocation LParenLoc,
+                SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_fail, StartLoc, EndLoc),
+        ArgumentLoc(ArgumentLoc), LParenLoc(LParenLoc) {
+    // ArgumentLoc is already initialized; the stored clause is built with it.
+    setMemoryOrderClauseKind(FailParameter);
+  }
+
+  /// Build an empty clause.
+  OMPFailClause()
+      : OMPClause(llvm::omp::OMPC_fail, SourceLocation(), SourceLocation()) {}
+
+  static OMPFailClause *CreateEmpty(const ASTContext &C);
+  static OMPFailClause *Create(const ASTContext &C, SourceLocation StartLoc,
+                               SourceLocation EndLoc);
+  static OMPFailClause *Create(const ASTContext &C,
+                               OpenMPClauseKind FailParameter,
+                               SourceLocation ArgumentLoc,
+                               SourceLocation StartLoc,
+                               SourceLocation LParenLoc, SourceLocation EndLoc);
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_fail;
+  }
+
+  void initFailClause(SourceLocation LParenLoc, OMPClause *MemOClause,
+                      SourceLocation MemOrderLoc) {
+    // Set ArgumentLoc first: setMemoryOrderClauseKind reads it.
+    setLParenLoc(LParenLoc);
+    setArgumentLoc(MemOrderLoc);
+    // Only the kind of MemOClause is kept; the stored clause is rebuilt.
+    OpenMPClauseKind ClauseKind = (MemOClause == nullptr) ? llvm::omp::OMPC_unknown : MemOClause->getClauseKind();
+    setMemoryOrderClauseKind(ClauseKind);
+  }
+
+  /// Gets the location of '(' in fail clause.
+  SourceLocation getLParenLoc() const {
+    return LParenLoc;
+  }
+  /// Gets the stored memory-order clause (null if none was given or valid).
+  OMPClause *getMemoryOrderClause() { return FailMemoryOrderClause; }
+
+  const OMPClause *getMemoryOrderClause() const {
+    return static_cast<const OMPClause *>(FailMemoryOrderClause);
+  }
+
+  /// Gets the location of memoryOrder clause argument in fail clause.
+  SourceLocation getArgumentLoc() const {
+    return ArgumentLoc;
+  }
+
+  /// Gets the stored clause's kind (acquire/relaxed/seq_cst) or OMPC_unknown.
+  OpenMPClauseKind getMemoryOrderClauseKind() const {
+    OpenMPClauseKind CK = (FailMemoryOrderClause == nullptr) ? llvm::omp::OMPC_unknown:FailMemoryOrderClause->getClauseKind();
+    return CK;
+  }
+};
+
 /// This represents clause 'private' in the '#pragma omp ...' directives.
 ///
 /// \code
Index: clang/include/clang/AST/ASTNodeTraverser.h
===================================================================
--- clang/include/clang/AST/ASTNodeTraverser.h
+++ clang/include/clang/AST/ASTNodeTraverser.h
@@ -214,6 +214,14 @@
   }
 
   void Visit(const OMPClause *C) {
+    if (const auto *OMPC = dyn_cast_or_null<OMPFailClause>(C)) {
+      // A 'fail' clause's parameter is itself a memory-order clause; route to
+      // Visit(const OMPFailClause *) so it is dumped as a child (-ast-dump).
+      // C can be null (e.g. a fail clause with no memory-order argument).
+      Visit(OMPC);
+      return;
+    }
+
     getNodeDelegate().AddChild([=] {
       getNodeDelegate().Visit(C);
       for (const auto *S : C->children())
@@ -221,6 +229,14 @@
     });
   }
 
+  void Visit(const OMPFailClause *C) {
+    getNodeDelegate().AddChild([=] {
+      getNodeDelegate().Visit(C);
+      if (const OMPClause *MOC = C->getMemoryOrderClause()) // may be null
+        Visit(MOC);
+    });
+  }
+
   void Visit(const GenericSelectionExpr::ConstAssociation &A) {
     getNodeDelegate().AddChild([=] {
       getNodeDelegate().Visit(A);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to