================
@@ -5965,6 +5967,269 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
   return Checker.teamsLoopCanBeParallelFor();
 }
 
+/// Look through parentheses and implicit casts on \p Cond. If what remains is
+/// a reference to an OMPCapturedExprDecl, return that declaration's
+/// initializer; otherwise return nullptr. \p Cond must be non-null.
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+  auto *Ref = dyn_cast<DeclRefExpr>(Cond->IgnoreParenImpCasts());
+  if (!Ref)
+    return nullptr;
+  auto *Captured = dyn_cast<OMPCapturedExprDecl>(Ref->getDecl());
+  return Captured ? Captured->getInit() : nullptr;
+}
+
+// Forward declaration: cloneAssociatedStmt() below calls this helper, which is
+// defined after it.
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext);
+
+/// Clone the CapturedStmt associated with a directive and rewrite its body
+/// with replaceWithNewTraitsOrDirectCall(). The original associated statement
+/// is never modified in place.
+///
+/// \param Context   AST context in which the new nodes are allocated.
+/// \param StmtP     The directive's associated statement.
+/// \param SemaPtr   Forwarded unchanged to replaceWithNewTraitsOrDirectCall().
+/// \param NoContext Forwarded unchanged to replaceWithNewTraitsOrDirectCall().
+/// \returns The new CapturedStmt, or an empty (null but not invalid) result
+///          if \p StmtP is not a CapturedStmt.
+static StmtResult cloneAssociatedStmt(const ASTContext &Context, Stmt *StmtP,
+                                      SemaOpenMP *SemaPtr, bool NoContext) {
+  auto *AssocStmt = dyn_cast<CapturedStmt>(StmtP);
+  if (!AssocStmt)
+    return StmtResult();
+  CapturedDecl *CDecl = AssocStmt->getCapturedDecl();
+
+  // replaceWithNewTraitsOrDirectCall() dereferences its expression argument
+  // unconditionally, so the captured statement must be an Expr. cast<>
+  // asserts that instead of silently handing it a null pointer.
+  auto *AssocExpr = cast<Expr>(AssocStmt->getCapturedStmt());
+  Expr *NewBody =
+      replaceWithNewTraitsOrDirectCall(Context, AssocExpr, SemaPtr, NoContext);
+
+  // Build a fresh CapturedDecl with the original's parameters, nothrow flag
+  // and the rewritten body. CapturedDecl::Create() needs a mutable ASTContext.
+  CapturedDecl *NewDecl =
+      CapturedDecl::Create(const_cast<ASTContext &>(Context),
+                           CDecl->getDeclContext(), CDecl->getNumParams());
+  NewDecl->setBody(NewBody);
+  NewDecl->setNothrow(CDecl->isNothrow());
+  const unsigned ContextParamPos = CDecl->getContextParamPosition();
+  for (unsigned I : llvm::seq<unsigned>(CDecl->getNumParams())) {
+    if (I == ContextParamPos)
+      NewDecl->setContextParam(I, CDecl->getContextParam());
+    else
+      NewDecl->setParam(I, CDecl->getParam(I));
+  }
+
+  // The new CapturedStmt must wrap the same statement that is the new
+  // CapturedDecl's body. Passing the original captured statement here would
+  // leave the two out of sync, so users of getCapturedStmt() would still see
+  // the unmodified call.
+  SmallVector<CapturedStmt::Capture, 4> Captures(AssocStmt->captures());
+  SmallVector<Expr *, 4> CaptureInits(AssocStmt->capture_inits());
+  return CapturedStmt::Create(
+      Context, NewBody, AssocStmt->getCapturedRegionKind(), Captures,
+      CaptureInits, NewDecl,
+      const_cast<RecordDecl *>(AssocStmt->getCapturedRecordDecl()));
+}
+
+/// replaceWithNewTraitsOrDirectCall() is for transforming the call traits.
+/// Call traits associated with a function call are removed and replaced with
+/// a direct call. For clause "nocontext" only, the direct call is then
+/// modified to have call traits for a non-dispatch variant.
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext) {
+  BinaryOperator *BinaryCopyOpr = nullptr;
+  bool IsBinaryOp = false;
+  Expr *PseudoObjExprOrCall = AssocExpr;
+  if (auto *BinOprExpr = dyn_cast<BinaryOperator>(AssocExpr)) {
+    IsBinaryOp = true;
+    BinaryCopyOpr = BinaryOperator::Create(
+        Context, BinOprExpr->getLHS(), BinOprExpr->getRHS(),
+        BinOprExpr->getOpcode(), BinOprExpr->getType(),
+        BinOprExpr->getValueKind(), BinOprExpr->getObjectKind(),
+        BinOprExpr->getOperatorLoc(), FPOptionsOverride());
+    PseudoObjExprOrCall = BinaryCopyOpr->getRHS();
+  }
+
+  Expr *CallWithoutInvariants = PseudoObjExprOrCall;
+  // Change PseudoObjectExpr to a direct call
+  if (auto *PseudoObjExpr = dyn_cast<PseudoObjectExpr>(PseudoObjExprOrCall))
+    CallWithoutInvariants = *((PseudoObjExpr->semantics_begin()) - 1);
+
+  Expr *FinalCall = CallWithoutInvariants; // For noinvariants clause
+  if (NoContext) {
+    // example to explain the changes done for "nocontext" clause:
+    //
+    // #pragma omp declare variant(foo_variant_dispatch)
+    //                                  match(construct = {dispatch})
+    // #pragma omp declare variant(foo_variant_allCond)
+    //                                 match(user = {condition(1)})
+    // ...
+    //     #pragma omp dispatch nocontext(cond_true)
+    //         foo(i, j); // with traits: CodeGen call to
+    //         foo_variant_dispatch(i,j)
+    // dispatch construct is changed to:
+    // if (cond_true) {
+    //    foo(i,j) // with traits: CodeGen call to foo_variant_allCond(i,j)
+    // } else {
+    //   #pragma omp dispatch
+    //   foo(i,j)  // with traits: CodeGen call to foo_variant_dispatch(i,j)
+    // }
----------------
alexey-bataev wrote:

There are some extra additions. Sema should not produce extra OpenMP
constructs. You can generate the required CapturedStmts (one for the outer
taskwait and one for the inner dispatch) and then use them to generate code in
CodeGen. Sema should be as close to the source code as possible.

https://github.com/llvm/llvm-project/pull/117904
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to