abidmalikwaterloo updated this revision to Diff 423305.
abidmalikwaterloo added a comment.

Cleaned the code and added tests.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D122255/new/

https://reviews.llvm.org/D122255

Files:
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/AST/StmtOpenMP.h
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Parse/Parser.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/AST/StmtPrinter.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/SemaStmt.cpp
  clang/test/OpenMP/metadirective_ast_print_new_1.cpp
  clang/test/OpenMP/metadirective_ast_print_new_2.cpp
  clang/test/OpenMP/metadirective_ast_print_new_3.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPContext.h
  llvm/lib/Frontend/OpenMP/OMPContext.cpp

Index: llvm/lib/Frontend/OpenMP/OMPContext.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPContext.cpp
+++ llvm/lib/Frontend/OpenMP/OMPContext.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include <map>
 #define DEBUG_TYPE "openmp-ir-builder"
 
 using namespace llvm;
@@ -339,6 +340,41 @@
   return Score;
 }
 
+/// Score every variant in \p VMIs against the context \p Ctx and append the
+/// resulting (variant index, score) pairs to \p A, ordered from the highest
+/// to the lowest score.
+///
+/// Variants that are not applicable in \p Ctx are given a score of zero.
+/// Variants with equal scores keep their relative order from \p VMIs, so the
+/// result is deterministic. \p A is not cleared first; entries already in it
+/// take part in the sort as well.
+void llvm::omp::getArrayVariantMatchForContext(
+    const SmallVectorImpl<VariantMatchInfo> &VMIs, const OMPContext &Ctx,
+    SmallVector<std::pair<unsigned, APInt>> &A) {
+  A.reserve(A.size() + VMIs.size());
+
+  for (unsigned Idx = 0, End = VMIs.size(); Idx < End; ++Idx) {
+    const VariantMatchInfo &VMI = VMIs[Idx];
+    SmallVector<unsigned, 8> ConstructMatches;
+    // A variant that cannot apply in this context is ranked with score 0.
+    if (!isVariantApplicableInContextHelper(VMI, Ctx, &ConstructMatches,
+                                            /* DeviceSetOnly */ false)) {
+      A.emplace_back(Idx, APInt(64, 0));
+      continue;
+    }
+    A.emplace_back(Idx, getVariantMatchScore(VMI, Ctx, ConstructMatches));
+  }
+
+  // Highest score first. Use a stable sort so that variants with equal
+  // scores stay in their order within VMIs.
+  std::stable_sort(A.begin(), A.end(),
+                   [](const std::pair<unsigned, APInt> &LHS,
+                      const std::pair<unsigned, APInt> &RHS) {
+                     return LHS.second.ugt(RHS.second);
+                   });
+}
+
+
 int llvm::omp::getBestVariantMatchForContext(
     const SmallVectorImpl<VariantMatchInfo> &VMIs, const OMPContext &Ctx) {
 
Index: llvm/include/llvm/Frontend/OpenMP/OMPContext.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPContext.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPContext.h
@@ -189,6 +189,15 @@
 int getBestVariantMatchForContext(const SmallVectorImpl<VariantMatchInfo> &VMIs,
                                   const OMPContext &Ctx);
 
+/// Compute a match score for every variant in \p VMIs under the context
+/// \p Ctx and append the (variant index, score) pairs to \p A, ordered from
+/// the highest to the lowest score. Variants that are not applicable in
+/// \p Ctx get a score of zero. The metadirective parser uses this ordering to
+/// arrange the 'when' clauses of a metadirective.
+void getArrayVariantMatchForContext(
+    const SmallVectorImpl<VariantMatchInfo> &VMIs, const OMPContext &Ctx,
+    SmallVector<std::pair<unsigned, APInt>> &A);
+
 } // namespace omp
 
 template <> struct DenseMapInfo<omp::TraitProperty> {
Index: clang/test/OpenMP/metadirective_ast_print_new_3.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/metadirective_ast_print_new_3.cpp
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -verify  -fopenmp  -ast-print %s -o - | FileCheck %s
+// expected-no-diagnostics
+// Metadirectives with 'when' and 'default' clauses, printed back with -ast-print.
+int main() {
+  int N = 15;
+#pragma omp metadirective when(user = {condition(N > 10)} : parallel for)\
+ 				default(target teams) 
+  for (int i = 0; i < N; i++)
+    ;
+
+  // 'when' clause whose selector has both a device arch and a user condition.
+#pragma omp metadirective when(device = {arch("nvptx64")}, user = {condition(N >= 100)} : parallel for)\
+  				default(target parallel for)
+  for (int i = 0; i < N; i++)
+    ;
+  return 0;
+}
+
+
+
+// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: parallel for) default(target teams)
+// CHECK: #pragma omp metadirective when(device={arch(nvptx64)}, user={condition(N >= 100)}: parallel for) default(target parallel for)
Index: clang/test/OpenMP/metadirective_ast_print_new_2.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/metadirective_ast_print_new_2.cpp
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -verify  -fopenmp  -ast-print %s -o - | FileCheck %s
+// expected-no-diagnostics
+
+void bar(){
+	int i=0;	
+}
+
+void myfoo(void){
+	// b, n and a are unused; only N is referenced by the metadirective.
+	int N = 13;
+	int b,n;
+	int a[100];
+
+	// Five 'when' clauses (3 user conditions, 2 device archs) plus a 'default'.
+	#pragma omp  metadirective when (user = {condition(N>10)}: target  teams distribute parallel for ) \
+					when (user = {condition(N==10)}: parallel for )\
+					when (user = {condition(N==13)}: parallel for simd) \
+					when ( device={arch("arm")}:   target teams num_teams(512) thread_limit(32))\
+					when ( device={arch("nvptx")}: target teams num_teams(512) thread_limit(32))\
+					default ( parallel for)\
+
+	{		for (int i = 0; i < N; i++)
+		bar();
+	}
+}
+
+// CHECK: bar()
+// CHECK: myfoo
+// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: target teams distribute parallel for) when(user={condition(N == 13)}: parallel for simd) when(device={arch(nvptx)}: target teams)
Index: clang/test/OpenMP/metadirective_ast_print_new_1.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/metadirective_ast_print_new_1.cpp
@@ -0,0 +1,20 @@
+// RUN: %clang_cc1 -verify  -fopenmp  -ast-print %s -o - | FileCheck %s
+// expected-no-diagnostics
+void bar(){
+        int i=0;
+}
+
+void myfoo(void){
+        // b, n and a are unused; only N is referenced by the metadirective.
+        int N = 13;
+        int b,n;
+        int a[100];
+        // One 'when' clause with a user condition, followed by 'default'.
+        #pragma omp metadirective when(user={condition(N>10)}:  target teams ) default(parallel for)
+                for (int i = 0; i < N; i++)
+                bar();
+
+}
+
+// CHECK: void bar()
+// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: target teams) default(parallel for)
Index: clang/lib/Sema/SemaStmt.cpp
===================================================================
--- clang/lib/Sema/SemaStmt.cpp
+++ clang/lib/Sema/SemaStmt.cpp
@@ -4791,8 +4791,8 @@
   CapturedStmt *Res = CapturedStmt::Create(
       getASTContext(), S, static_cast<CapturedRegionKind>(RSI->CapRegionKind),
       Captures, CaptureInits, CD, RD);
-
-  CD->setBody(Res->getCapturedStmt());
+  		
+  CD->setBody(Res->getCapturedStmt());   
   RD->completeDefinition();
 
   return Res;
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -37,6 +37,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPContext.h"
 #include <set>
 
 using namespace clang;
@@ -3930,6 +3931,7 @@
 
 void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
   switch (DKind) {
+  case OMPD_metadirective:
   case OMPD_parallel:
   case OMPD_parallel_for:
   case OMPD_parallel_for_simd:
@@ -4339,7 +4341,7 @@
   case OMPD_declare_variant:
   case OMPD_begin_declare_variant:
   case OMPD_end_declare_variant:
-  case OMPD_metadirective:
+  //case OMPD_metadirective:
     llvm_unreachable("OpenMP Directive is not allowed");
   case OMPD_unknown:
   default:
@@ -4522,6 +4524,7 @@
 
 StmtResult Sema::ActOnOpenMPRegionEnd(StmtResult S,
                                       ArrayRef<OMPClause *> Clauses) {
+                                                                        
   handleDeclareVariantConstructTrait(DSAStack, DSAStack->getCurrentDirective(),
                                      /* ScopeEntry */ false);
   if (DSAStack->getCurrentDirective() == OMPD_atomic ||
@@ -4590,6 +4593,7 @@
     else if (Clause->getClauseKind() == OMPC_linear)
       LCs.push_back(cast<OMPLinearClause>(Clause));
   }
+  
   // Capture allocator expressions if used.
   for (Expr *E : DSAStack->getInnerAllocators())
     MarkDeclarationsReferencedInExpr(E);
@@ -4611,6 +4615,7 @@
         << SourceRange(OC->getBeginLoc(), OC->getEndLoc());
     ErrorFound = true;
   }
+  
   // OpenMP 5.0, 2.9.2 Worksharing-Loop Construct, Restrictions.
   // If an order(concurrent) clause is present, an ordered clause may not appear
   // on the same directive.
@@ -4623,6 +4628,7 @@
     }
     ErrorFound = true;
   }
+  
   if (isOpenMPWorksharingDirective(DSAStack->getCurrentDirective()) &&
       isOpenMPSimdDirective(DSAStack->getCurrentDirective()) && OC &&
       OC->getNumForLoops()) {
@@ -4635,7 +4641,9 @@
   }
   StmtResult SR = S;
   unsigned CompletedRegions = 0;
+  
   for (OpenMPDirectiveKind ThisCaptureRegion : llvm::reverse(CaptureRegions)) {
+  
     // Mark all variables in private list clauses as used in inner region.
     // Required for proper codegen of combined directives.
     // TODO: add processing for other clauses.
@@ -4656,6 +4664,7 @@
         }
       }
     }
+    
     if (ThisCaptureRegion == OMPD_target) {
       // Capture allocator traits in the target region. They are used implicitly
       // and, thus, are not captured by default.
@@ -4671,6 +4680,7 @@
         }
       }
     }
+    
     if (ThisCaptureRegion == OMPD_parallel) {
       // Capture temp arrays for inscan reductions and locals in aligned
       // clauses.
@@ -4687,10 +4697,14 @@
         }
       }
     }
+    
     if (++CompletedRegions == CaptureRegions.size())
       DSAStack->setBodyComplete();
+    
     SR = ActOnCapturedRegionEnd(SR.get());
+    
   }
+  
   return SR;
 }
 
@@ -5963,6 +5977,12 @@
 
   llvm::SmallVector<OpenMPDirectiveKind, 4> AllowedNameModifiers;
   switch (Kind) {
+
+  case OMPD_metadirective:
+    Res = ActOnOpenMPMetaDirective(ClausesWithImplicit, AStmt, StartLoc, 
+    					EndLoc);
+    AllowedNameModifiers.push_back(OMPD_metadirective); 
+    break;
   case OMPD_parallel:
     Res = ActOnOpenMPParallelDirective(ClausesWithImplicit, AStmt, StartLoc,
                                        EndLoc);
@@ -7342,6 +7362,116 @@
   FD->addAttr(NewAttr);
 }
 
+/// Build '#pragma omp metadirective'.
+///
+/// The 'when'/'default' clauses, walked from last to first, are lowered into
+/// a chain of 'if' statements: each 'when' clause contributes a condition and
+/// the statement it selects, and the 'default' clause (the one without trait
+/// sets) becomes the final 'else'. The chain is stored on the directive.
+StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef<OMPClause *> Clauses,
+                                          Stmt *AStmt,
+                                          SourceLocation StartLoc,
+                                          SourceLocation EndLoc) {
+  if (!AStmt)
+    return StmtError();
+
+  auto *CS = cast<CapturedStmt>(AStmt);
+  CS->getCapturedDecl()->setNothrow();
+
+  // True if a property of \p Selector equals the arch (or, when \p Arch is
+  // false, the vendor) name of one of the offload target triples.
+  auto MatchesTarget = [&](const OMPTraitSelector &Selector, bool Arch) {
+    for (const OMPTraitProperty &Property : Selector.Properties)
+      for (const auto &TT : getLangOpts().OMPTargetTriples)
+        if ((Arch ? TT.getArchName() : TT.getVendorName()) ==
+            Property.RawString)
+          return true;
+    return false;
+  };
+
+  // Chain built so far; it becomes the 'else' of the next (earlier) clause.
+  Stmt *ElseStmt = nullptr;
+
+  for (OMPClause *Clause : llvm::reverse(Clauses)) {
+    // Skip anything that is not a 'when'/'default' clause instead of
+    // dereferencing the result of a failed cast.
+    auto *WhenClause = dyn_cast_or_null<OMPWhenClause>(Clause);
+    if (!WhenClause)
+      continue;
+
+    // Statement chosen by this clause: its nested directive, if any,
+    // otherwise the associated statement itself.
+    Stmt *Chosen = CS->getCapturedStmt();
+    if (WhenClause->getDKind() != OMPD_unknown) {
+      Chosen = WhenClause->getDirective();
+      // NOTE(review): null presumably means the directive failed to build.
+      if (!Chosen)
+        return StmtError();
+    }
+    Stmt *ThenStmt = CompoundStmt::Create(Context, {Chosen}, SourceLocation(),
+                                          SourceLocation());
+
+    // The 'default' clause has no trait sets and must be the last clause.
+    if (WhenClause->getTI().Sets.empty()) {
+      if (ElseStmt) {
+        Diag(WhenClause->getBeginLoc(), diag::err_omp_misplaced_default_clause);
+        return StmtError();
+      }
+      ElseStmt = ThenStmt;
+      continue;
+    }
+
+    // Every selector of the clause must hold. Arch and vendor are decided
+    // now; a user condition is kept as a runtime expression.
+    bool StaticMatch = true;
+    Expr *UserCond = nullptr;
+    for (const OMPTraitSet &Set : WhenClause->getTI().Sets) {
+      for (const OMPTraitSelector &Selector : Set.Selectors) {
+        switch (Selector.Kind) {
+        case TraitSelector::device_arch:
+          StaticMatch &= MatchesTarget(Selector, /*Arch=*/true);
+          break;
+        case TraitSelector::implementation_vendor:
+          StaticMatch &= MatchesTarget(Selector, /*Arch=*/false);
+          break;
+        case TraitSelector::user_condition:
+          assert(Selector.ScoreOrCondition &&
+                 "Ill-formed user condition, expected condition expression!");
+          UserCond = Selector.ScoreOrCondition;
+          break;
+        default:
+          // TODO: isa, kind, extension, ... are not evaluated yet. Treat them
+          // as not matching instead of mistaking the clause for 'default'.
+          StaticMatch = false;
+          break;
+        }
+      }
+    }
+
+    // A failed static selector yields 'false'; otherwise use the user
+    // condition, or 'true' when the clause has none.
+    Expr *WhenCondExpr = UserCond;
+    if (!StaticMatch || !UserCond)
+      WhenCondExpr = new (Context)
+          CXXBoolLiteralExpr(StaticMatch, Context.BoolTy, StartLoc);
+
+    StmtResult IfResult = ActOnIfStmt(
+        SourceLocation(), IfStatementKind::Ordinary, SourceLocation(),
+        /*InitStmt=*/nullptr,
+        ActOnCondition(getCurScope(), SourceLocation(), WhenCondExpr,
+                       Sema::ConditionKind::Boolean),
+        SourceLocation(), ThenStmt, SourceLocation(), ElseStmt);
+    if (IfResult.isInvalid())
+      return StmtError();
+    ElseStmt = IfResult.get();
+  }
+
+  // ElseStmt is the whole chain, or only the 'default' statement when there
+  // is no 'when' clause.
+  return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
+                                  ElseStmt);
+}
+
 StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef<OMPClause *> Clauses,
                                               Stmt *AStmt,
                                               SourceLocation StartLoc,
@@ -14837,6 +14967,17 @@
   return std::string(Out.str());
 }
 
+/// Wrap a parsed metadirective 'when'/'default' clause in an OMPWhenClause;
+/// \p Directive.get() is null if there is no nested directive.
+OMPClause *
+Sema::ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind,
+                            StmtResult Directive, SourceLocation StartLoc,
+                            SourceLocation LParenLoc, SourceLocation EndLoc) {
+  Stmt *NestedDirective = Directive.get();
+  return new (Context) OMPWhenClause(TI, DKind, NestedDirective, StartLoc,
+                                     LParenLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPDefaultClause(DefaultKind Kind,
                                           SourceLocation KindKwLoc,
                                           SourceLocation StartLoc,
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -2430,6 +2430,7 @@
 ///
 StmtResult
 Parser::ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx) {
+// need to check about the following
   static bool ReadDirectiveWithinMetadirective = false;
   if (!ReadDirectiveWithinMetadirective)
     assert(Tok.isOneOf(tok::annot_pragma_openmp, tok::annot_attr_openmp) &&
@@ -2470,10 +2471,13 @@
 
     BalancedDelimiterTracker T(*this, tok::l_paren,
                                tok::annot_pragma_openmp_end);
+                               
     while (Tok.isNot(tok::annot_pragma_openmp_end)) {
+  
       OpenMPClauseKind CKind = Tok.isAnnotation()
                                    ? OMPC_unknown
                                    : getOpenMPClauseKind(PP.getSpelling(Tok));
+                                   
       SourceLocation Loc = ConsumeToken();
 
       // Parse '('.
@@ -2491,7 +2495,7 @@
           return Directive;
         }
 
-        // Parse ':'
+        // Parse ':' // You have parsed the OpenMP Context in the meta directive 
         if (Tok.is(tok::colon))
           ConsumeAnyToken();
         else {
@@ -2499,8 +2503,10 @@
           TPA.Commit();
           return Directive;
         }
-      }
+      } // if (CKind == OMPC_when)  statement ends
+     
       // Skip Directive for now. We will parse directive in the second iteration
+      // This needs to be caught.
       int paren = 0;
       while (Tok.isNot(tok::r_paren) || paren != 0) {
         if (Tok.is(tok::l_paren))
@@ -2513,86 +2519,105 @@
           TPA.Commit();
           return Directive;
         }
-        ConsumeAnyToken();
-      }
+        ConsumeAnyToken();  
+      } // end of the while statement while (Tok.isNot(tok::r_paren)
+      
       // Parse ')'
       if (Tok.is(tok::r_paren))
         T.consumeClose();
-
+      
       VariantMatchInfo VMI;
       TI.getAsVariantMatchInfo(ASTContext, VMI);
-
-      VMIs.push_back(VMI);
-    }
-
+      
+      if (CKind == OMPC_when )	
+     		 VMIs.push_back(VMI);
+    } // end of while (Tok.isNot(tok::annot_pragma_openmp_end))
+    
+    // This is the end of the first iteration
+    // The pointer is moved back
     TPA.Revert();
     // End of the first iteration. Parser is reset to the start of metadirective
-
+    
     TargetOMPContext OMPCtx(ASTContext, /* DiagUnknownTrait */ nullptr,
                             /* CurrentFunctionDecl */ nullptr,
                             ArrayRef<llvm::omp::TraitProperty>());
-
-    // A single match is returned for OpenMP 5.0
-    int BestIdx = getBestVariantMatchForContext(VMIs, OMPCtx);
-
-    int Idx = 0;
-    // In OpenMP 5.0 metadirective is either replaced by another directive or
-    // ignored.
-    // TODO: In OpenMP 5.1 generate multiple directives based upon the matches
-    // found by getBestWhenMatchForContext.
-    while (Tok.isNot(tok::annot_pragma_openmp_end)) {
-      // OpenMP 5.0 implementation - Skip to the best index found.
-      if (Idx++ != BestIdx) {
-        ConsumeToken();  // Consume clause name
-        T.consumeOpen(); // Consume '('
-        int paren = 0;
-        // Skip everything inside the clause
-        while (Tok.isNot(tok::r_paren) || paren != 0) {
-          if (Tok.is(tok::l_paren))
-            paren++;
-          if (Tok.is(tok::r_paren))
-            paren--;
-          ConsumeAnyToken();
-        }
-        // Parse ')'
-        if (Tok.is(tok::r_paren))
-          T.consumeClose();
-        continue;
-      }
-
+    
+    // Array A will be used for sorting                        
+    SmallVector<std::pair<unsigned, llvm::APInt>> A;
+    
+    // The function will get the score for each clause and sort it 
+    // based on the score number
+    
+    getArrayVariantMatchForContext(VMIs, OMPCtx, A) ;          
+                              	    
+    ParseScope OMPDirectiveScope(this, ScopeFlags);	
+    Actions.StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc);  
+    
+   while(Tok.isNot(tok::annot_pragma_openmp_end)){
+   	
       OpenMPClauseKind CKind = Tok.isAnnotation()
                                    ? OMPC_unknown
                                    : getOpenMPClauseKind(PP.getSpelling(Tok));
-      SourceLocation Loc = ConsumeToken();
-
-      // Parse '('.
-      T.consumeOpen();
-
-      // Skip ContextSelectors for when clause
-      if (CKind == OMPC_when) {
-        OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
-        // parse and skip the ContextSelectors
-        parseOMPContextSelectors(Loc, TI);
-
-        // Parse ':'
-        ConsumeAnyToken();
-      }
-
-      // If no directive is passed, skip in OpenMP 5.0.
-      // TODO: Generate nothing directive from OpenMP 5.1.
-      if (Tok.is(tok::r_paren)) {
-        SkipUntil(tok::annot_pragma_openmp_end);
-        break;
-      }
-
-      // Parse Directive
-      ReadDirectiveWithinMetadirective = true;
-      Directive = ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx);
-      ReadDirectiveWithinMetadirective = false;
-      break;
-    }
-    break;
-  }
+                                   
+      Actions.StartOpenMPClause(CKind);
+      OMPClause *Clause = ParseOpenMPMetaDirectiveClause( DKind, CKind);
+      
+      FirstClauses[(unsigned) CKind].setInt(true);
+      if (Clause) {
+      	FirstClauses[(unsigned) CKind].setPointer(Clause);
+      	Clauses.push_back(Clause);
+      }// end of if statement	
+      
+      if (Tok.is(tok::comma))
+      	ConsumeToken();
+      
+      Actions.EndOpenMPClause();
+      
+      if (Tok.is(tok::r_paren))
+      	ConsumeAnyToken();	
+      		
+   }// end of the while loop
+   
+   // End location of the directive
+   EndLoc = Tok.getLocation();
+   
+   //Consume final annot_pragma_openmp_end
+   ConsumeAnnotationToken();
+   
+   SmallVector<OMPClause *, 5> Clauses_new;
+   unsigned count = 0;
+         
+   for ( auto &it1 : A){
+   	count = 0;
+   	for ( auto &it2 : Clauses){
+   		if ( count == it1.first ){
+   			Clauses_new.push_back(it2);
+   			break;
+   		} else count++;
+   	}
+   }// end of the for loop
+   
+   Clauses_new.push_back(Clauses.back());
+   
+   // Parsing the OpenMP region which will take the
+   // metadirective
+   
+   Actions.ActOnOpenMPRegionStart(DKind, getCurScope());
+   ParsingOpenMPDirectiveRAII NormalScope(*this, /*value=*/ false);
+   // This is parsing the region
+   StmtResult AStmt = ParseStatement();
+   
+   StmtResult AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt);
+   // End of the metadirective region.
+   AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses_new);
+   Directive = Actions.ActOnOpenMPExecutableDirective(
+   	DKind, DirName, CancelRegion, Clauses_new, AssociatedStmt.get(), Loc, 
+   	EndLoc);
+   // Exit scope
+   Actions.EndOpenMPDSABlock(Directive.get());
+   OMPDirectiveScope.Exit();
+   break;
+  } // end of case OMPD_metadirective:
   case OMPD_threadprivate: {
     // FIXME: Should this be permitted in C++?
     if ((StmtCtx & ParsedStmtContext::AllowDeclarationsInC) ==
@@ -3050,6 +3075,164 @@
   return Actions.ActOnOpenMPUsesAllocatorClause(Loc, T.getOpenLocation(),
                                                 T.getCloseLocation(), Data);
 }
+
+/// Parse one clause of '#pragma omp metadirective'.
+///
+/// Only 'when' and 'default' clauses are accepted:
+///   when(<context-selector> : [<directive>])   default([<directive>])
+/// The optional nested directive is built against a tentatively parsed copy
+/// of the statement that follows the pragma and is wrapped, together with the
+/// context selector, in an OMPWhenClause. Any other clause is diagnosed and
+/// skipped up to the next ',' or the end of the pragma, so the caller's
+/// clause loop always makes progress.
+///
+/// \param DKind the enclosing directive; anything other than
+///        OMPD_metadirective is diagnosed.
+/// \param CKind the kind of the clause that starts at the current token.
+/// \returns the new clause, or nullptr if the clause was erroneous.
+OMPClause *Parser::ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind,
+                                                  OpenMPClauseKind CKind) {
+  // Reject (and skip) everything but 'when'/'default' on a metadirective.
+  // Leaving the tokens unconsumed would make the caller loop forever.
+  if (DKind != OMPD_metadirective ||
+      (CKind != OMPC_when && CKind != OMPC_default) ||
+      !isAllowedClauseForDirective(DKind, CKind, getLangOpts().OpenMP)) {
+    Diag(Tok, diag::err_omp_unexpected_clause)
+        << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind);
+    SkipUntil(tok::comma, tok::annot_pragma_openmp_end, StopBeforeMatch);
+    return nullptr;
+  }
+
+  bool WrongDirective = false;
+
+  // Per clause kind: whether it was already seen in the nested directive
+  // (passed to ParseOpenMPClause as its "first clause" flag).
+  SmallVector<llvm::PointerIntPair<OMPClause *, 1, bool>,
+              llvm::omp::Clause_enumSize + 1>
+      FirstClauses(llvm::omp::Clause_enumSize + 1);
+
+  // Consume the clause name, then parse '('.
+  SourceLocation Loc = ConsumeToken();
+  BalancedDelimiterTracker T(*this, tok::l_paren,
+                             tok::annot_pragma_openmp_end);
+  if (T.expectAndConsume(diag::err_expected_lparen_after,
+                         getOpenMPClauseName(CKind).data()))
+    return nullptr;
+
+  // A 'default' clause keeps an empty trait info; that is what tells it
+  // apart from a 'when' clause later on.
+  OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
+  if (CKind == OMPC_when) {
+    // Parse the context selector into TI.
+    parseOMPContextSelectors(Loc, TI);
+
+    // Parse ':'.
+    if (Tok.is(tok::colon)) {
+      ConsumeAnyToken();
+    } else {
+      Diag(Tok, diag::warn_pragma_expected_colon) << "when clause";
+      // Drop the rest of the clause, including its ')'.
+      T.skipToEnd();
+      return nullptr;
+    }
+  }
+
+  // Parse the optional nested directive.
+  OpenMPDirectiveKind DirKind = OMPD_unknown;
+  SmallVector<OMPClause *, 5> Clauses;
+  StmtResult AssociatedStmt;
+  StmtResult Directive = StmtError();
+
+  if (Tok.isNot(tok::r_paren)) {
+    ParsingOpenMPDirectiveRAII DirScope(*this);
+    ParenBraceBracketBalancer BalancerRAIIObj(*this);
+    DeclarationNameInfo DirName;
+    unsigned ScopeFlags = Scope::FnScope | Scope::DeclScope |
+                          Scope::CompoundStmtScope |
+                          Scope::OpenMPDirectiveScope;
+
+    DirKind = parseOpenMPDirectiveKind(*this);
+    if (DirKind == OMPD_unknown) {
+      // Sema cannot build an unknown directive; diagnose and drop the clause.
+      Diag(Tok, diag::err_omp_unknown_directive);
+      T.skipToEnd();
+      return nullptr;
+    }
+    ConsumeToken();
+    ParseScope OMPDirectiveScope(this, ScopeFlags);
+    Actions.StartOpenMPDSABlock(DirKind, DirName, Actions.getCurScope(), Loc);
+
+    // Parse the nested directive's clauses up to the closing ')'. Also stop
+    // at the end of the pragma: clause parsers may skip past the ')' after
+    // an error, and the loop must still terminate.
+    while (Tok.isNot(tok::r_paren) &&
+           Tok.isNot(tok::annot_pragma_openmp_end)) {
+      OpenMPClauseKind InnerKind =
+          Tok.isAnnotation() ? OMPC_unknown
+                             : getOpenMPClauseKind(PP.getSpelling(Tok));
+      // Report unknown clauses against the nested directive.
+      if (InnerKind == OMPC_unknown) {
+        Diag(Tok, diag::err_omp_unexpected_clause)
+            << getOpenMPClauseName(InnerKind)
+            << getOpenMPDirectiveName(DirKind);
+        WrongDirective = true;
+      }
+
+      Actions.StartOpenMPClause(InnerKind);
+      OMPClause *Clause = ParseOpenMPClause(
+          DirKind, InnerKind, !FirstClauses[(unsigned)InnerKind].getInt());
+      FirstClauses[(unsigned)InnerKind].setInt(true);
+      if (Clause) {
+        FirstClauses[(unsigned)InnerKind].setPointer(Clause);
+        Clauses.push_back(Clause);
+      }
+
+      // Skip ',' if any.
+      if (Tok.is(tok::comma))
+        ConsumeToken();
+      Actions.EndOpenMPClause();
+    }
+
+    Actions.ActOnOpenMPRegionStart(DirKind, getCurScope());
+    ParsingOpenMPDirectiveRAII NormalScope(*this, /*Value=*/false);
+
+    // Parse the statement after the pragma as the nested directive's
+    // associated statement, then rewind so the metadirective parses it too.
+    // NOTE(review): TPA.Revert() presumably rewinds only the token stream;
+    // confirm that Sema state created while parsing the statement is OK.
+    TentativeParsingAction TPA(*this);
+    while (Tok.isNot(tok::annot_pragma_openmp_end))
+      ConsumeAnyToken();
+    ConsumeAnnotationToken();
+    ParseScope InnerStmtScope(this, Scope::DeclScope,
+                              getLangOpts().C99 || getLangOpts().CPlusPlus,
+                              Tok.is(tok::l_brace));
+    StmtResult AStmt = ParseStatement();
+    InnerStmtScope.Exit();
+    TPA.Revert();
+
+    AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt);
+    AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses);
+
+    Directive = Actions.ActOnOpenMPExecutableDirective(
+        DirKind, DirName, OMPD_unknown, llvm::makeArrayRef(Clauses),
+        AssociatedStmt.get(), Loc, Tok.getLocation());
+
+    Actions.EndOpenMPDSABlock(Directive.get());
+    OMPDirectiveScope.Exit();
+  }
+
+  // Parse ')'.
+  SourceLocation EndLoc = Tok.getLocation();
+  if (!T.consumeClose())
+    EndLoc = T.getCloseLocation();
+
+  if (WrongDirective)
+    return nullptr;
+
+  return Actions.ActOnOpenMPWhenClause(TI, DirKind, Directive, Loc,
+                                       T.getOpenLocation(), EndLoc);
+}
 
 /// Parsing of OpenMP clauses.
 ///
Index: clang/lib/AST/StmtPrinter.cpp
===================================================================
--- clang/lib/AST/StmtPrinter.cpp
+++ clang/lib/AST/StmtPrinter.cpp
@@ -655,12 +655,25 @@
 
 void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S,
                                               bool ForceNoStmt) {
+    
   OMPClausePrinter Printer(OS, Policy);
   ArrayRef<OMPClause *> Clauses = S->clauses();
-  for (auto *Clause : Clauses)
+  for (auto *Clause : Clauses){    
     if (Clause && !Clause->isImplicit()) {
       OS << ' ';
       Printer.Visit(Clause);
+      if (dyn_cast<OMPMetaDirective>(S)){
+     
+      	OMPWhenClause *c = dyn_cast<OMPWhenClause>(Clause);
+      	if (c!=NULL){
+      		if (c->getDKind() != llvm::omp::OMPD_unknown){
+      			Printer.VisitOMPWhenClause(c);
+      			OS << llvm::omp::getOpenMPDirectiveName(c->getDKind());
+      			}
+      		OS << ")";	
+      		}
+      	    }
+    	}	
     }
   OS << NL;
   if (!ForceNoStmt && S->hasAssociatedStmt())
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -1609,6 +1609,24 @@
 //  OpenMP clauses printing methods
 //===----------------------------------------------------------------------===//
 
+/// Print the part of a metadirective 'when'/'default' clause that precedes
+/// its nested directive: "default(" for a clause without trait sets,
+/// otherwise "when(" followed by the context selector and ": ". The caller
+/// is expected to print the nested directive and the closing ')'.
+void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) {
+  const OMPTraitInfo &TI = Node->getTI();
+  if (TI.Sets.empty()) {
+    OS << "default(";
+    return;
+  }
+  OS << "when(";
+  // OMPTraitInfo::print writes "set={selector(...), ...}, ..." with all the
+  // separators between sets, selectors and properties, any score, and every
+  // selector kind.
+  TI.print(OS, Policy);
+  OS << ": ";
+}
+
 void OMPClausePrinter::VisitOMPIfClause(OMPIfClause *Node) {
   OS << "if(";
   if (Node->getNameModifier() != OMPD_unknown)
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -66,6 +66,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TinyPtrVector.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPContext.h"
 #include <deque>
 #include <memory>
 #include <string>
@@ -10678,10 +10679,16 @@
   ///
   /// \returns Statement for finished OpenMP region.
   StmtResult ActOnOpenMPRegionEnd(StmtResult S, ArrayRef<OMPClause *> Clauses);
+  
+  /// Called on well-formed
   StmtResult ActOnOpenMPExecutableDirective(
       OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
       OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
       Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc);
+  StmtResult ActOnOpenMPExecutableMetaDirective( 
+      OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName, 
+      OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
+      Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc);
   /// Called on well-formed '\#pragma omp parallel' after parsing
   /// of the  associated statement.
   StmtResult ActOnOpenMPParallelDirective(ArrayRef<OMPClause *> Clauses,
@@ -11127,7 +11134,9 @@
                                      SourceLocation LParenLoc,
                                      SourceLocation EndLoc);
   /// Called on well-formed 'when' clause.
-  OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, SourceLocation StartLoc,
+  OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind,
+                                   StmtResult Directive,
+                                   SourceLocation StartLoc,
                                    SourceLocation LParenLoc,
                                    SourceLocation EndLoc);
   /// Called on well-formed 'default' clause.
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -3291,6 +3291,13 @@
   /// \param StmtCtx The context in which we're parsing the directive.
   StmtResult
   ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx);
+  /// Parses a clause of kind \a CKind for a metadirective.
+  ///
+  /// \param DKind Kind of current directive.
+  /// \param CKind Kind of current clause.
+  /// \returns The parsed clause.
+  OMPClause *ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind,
+                                            OpenMPClauseKind CKind);
   /// Parses clause of kind \a CKind for directive of a kind \a Kind.
   ///
   /// \param DKind Kind of current directive.
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10848,6 +10848,9 @@
   "'%0' clause requires 'dispatch' context selector">;
 def err_omp_append_args_with_varargs : Error<
   "'append_args' is not allowed with varargs functions">;
+def err_omp_misplaced_default_clause : Error<
+  "misplaced default clause; only one default clause is allowed in a "
+  "metadirective and it must be the last clause">;
 } // end of OpenMP category
 
 let CategoryName = "Related Result Type Issue" in {
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -5476,6 +5476,7 @@
                                   Stmt *AssociatedStmt, Stmt *IfStmt);
   static OMPMetaDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
                                        EmptyShell);
+  /// Returns the IfStmt this directive was created with.
   Stmt *getIfStmt() const { return IfStmt; }
 
   static bool classof(const Stmt *T) {
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -501,7 +501,8 @@
   /// Process clauses with pre-initis.
   bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node);
   bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node);
+  bool VisitOMPWhenClause(OMPWhenClause *C);
 
   bool PostVisitStmt(Stmt *S);
 };
 
@@ -3136,6 +3137,21 @@
   return true;
 }
 
+/// Traverses the condition expression of every 'user={condition(...)}'
+/// selector in the clause. NOTE(review): OMPWhenClause::children() is empty
+/// and getDirective() is not traversed here; confirm both are intended.
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPWhenClause(OMPWhenClause *C) {
+  for (const OMPTraitSet &Set : C->getTI().Sets) {
+    for (const OMPTraitSelector &Selector : Set.Selectors) {
+      if (Selector.Kind == llvm::omp::TraitSelector::user_condition &&
+          Selector.ScoreOrCondition)
+        TRY_TO(TraverseStmt(Selector.ScoreOrCondition));
+    }
+  }
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPDefaultClause(OMPDefaultClause *) {
   return true;
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -8611,25 +8611,7 @@
 template<class ImplClass, typename RetTy = void>
 class ConstOMPClauseVisitor :
       public OMPClauseVisitorBase <ImplClass, const_ptr, RetTy> {};
 
-class OMPClausePrinter final : public OMPClauseVisitor<OMPClausePrinter> {
-  raw_ostream &OS;
-  const PrintingPolicy &Policy;
-
-  /// Process clauses with list of variables.
-  template <typename T> void VisitOMPClauseList(T *Node, char StartSym);
-  /// Process motion clauses.
-  template <typename T> void VisitOMPMotionClause(T *Node);
-
-public:
-  OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy)
-      : OS(OS), Policy(Policy) {}
-
-#define GEN_CLANG_CLAUSE_CLASS
-#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S);
-#include "llvm/Frontend/OpenMP/OMP.inc"
-};
-
 struct OMPTraitProperty {
   llvm::omp::TraitProperty Kind = llvm::omp::TraitProperty::invalid;
 
@@ -8872,6 +8854,94 @@
   }
 };
 
+/// This represents the 'when' clause of '#pragma omp metadirective'.
+/// \code
+/// #pragma omp metadirective when(user={condition(N<100)}:parallel for)
+/// \endcode
+/// In this example, the 'parallel for' variant is used for the associated
+/// code when the user condition 'N<100' is satisfied.
+class OMPWhenClause final : public OMPClause {
+  friend class OMPClauseReader;
+
+  OMPTraitInfo *TI = nullptr;
+  OpenMPDirectiveKind DKind = llvm::omp::OMPD_unknown;
+  Stmt *Directive = nullptr;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+public:
+  /// Build 'when' clause.
+  ///
+  /// \param T Trait info describing the context selector of the clause.
+  /// \param dKind The directive variant associated with the when clause.
+  /// \param D The statement associated with the when clause.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPWhenClause(OMPTraitInfo &T, OpenMPDirectiveKind dKind, Stmt *D,
+                SourceLocation StartLoc, SourceLocation LParenLoc,
+                SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_when, StartLoc, EndLoc), TI(&T), DKind(dKind),
+        Directive(D), LParenLoc(LParenLoc) {}
+
+  /// Build an empty clause (members keep their null/unknown defaults).
+  OMPWhenClause()
+      : OMPClause(llvm::omp::OMPC_when, SourceLocation(), SourceLocation()) {}
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns the directive variant kind.
+  OpenMPDirectiveKind getDKind() const { return DKind; }
+
+  /// Returns the statement associated with the when clause.
+  Stmt *getDirective() const { return Directive; }
+
+  /// Returns the OMPTraitInfo; must not be called on an empty clause.
+  OMPTraitInfo &getTI() { return *TI; }
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_when;
+  }
+};
+
+class OMPClausePrinter final : public OMPClauseVisitor<OMPClausePrinter> {
+  raw_ostream &OS;
+  const PrintingPolicy &Policy;
+
+  /// Process clauses with list of variables.
+  template <typename T> void VisitOMPClauseList(T *Node, char StartSym);
+  /// Process motion clauses.
+  template <typename T> void VisitOMPMotionClause(T *Node);
+
+public:
+  OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy)
+      : OS(OS), Policy(Policy) {}
+
+  void VisitOMPWhenClause(OMPWhenClause *Node);
+
+#define GEN_CLANG_CLAUSE_CLASS
+#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S);
+#include "llvm/Frontend/OpenMP/OMP.inc"
+};
 } // namespace clang
 
 #endif // LLVM_CLANG_AST_OPENMPCLAUSE_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to