grokos created this revision.
grokos added a project: clang.

This patch completes Sema support for the "declare target" directive. 
With this patch, Sema handles implicitly used functions (i.e. functions that 
are used inside a target region without having been declared "declare target"), 
including lambdas, templated functions, functions called from within target 
functions, and constructors/destructors.

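Implicit declare target handling is gated on the new language option 
OpenMPImplicitDeclareTarget. LangOptions.def in this patch declares it with a 
default of 0 (disabled); an upcoming driver patch will add the flag that controls it.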


Repository:
  rL LLVM

https://reviews.llvm.org/D38798

Files:
  include/clang/Basic/LangOptions.def
  include/clang/Sema/Sema.h
  include/clang/Sema/SemaInternal.h
  lib/Parse/ParseOpenMP.cpp
  lib/Sema/SemaDecl.cpp
  lib/Sema/SemaOpenMP.cpp

Index: lib/Sema/SemaOpenMP.cpp
===================================================================
--- lib/Sema/SemaOpenMP.cpp
+++ lib/Sema/SemaOpenMP.cpp
@@ -19,6 +19,7 @@
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclOpenMP.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/StmtVisitor.h"
@@ -1139,6 +1140,124 @@
   return false;
 }
 
+namespace {
+/// Marks every function reachable from a visited body as declare target.
+class ImplicitDeviceFunctionChecker
+    : public RecursiveASTVisitor<ImplicitDeviceFunctionChecker> {
+  Sema &SemaRef;
+
+public:
+  explicit ImplicitDeviceFunctionChecker(Sema &S) : SemaRef(S) {}
+
+  /// Mark the lambda's closure type with OMPDeclareTargetDeclAttr.
+  bool TraverseLambdaCapture(LambdaExpr *LE, const LambdaCapture *C,
+                             Expr *Init);
+
+  /// Mark \p F with OMPDeclareTargetDeclAttr and traverse it.
+  bool VisitFunctionDecl(FunctionDecl *F);
+
+  /// Mark the direct callee of \p Call with OMPDeclareTargetDeclAttr.
+  bool VisitCallExpr(CallExpr *Call);
+
+  /// Mark the invoked constructor and the constructed class's destructor.
+  bool VisitCXXConstructExpr(CXXConstructExpr *E);
+
+  /// Mark destructor \p D with OMPDeclareTargetDeclAttr.
+  bool VisitCXXDestructorDecl(CXXDestructorDecl *D);
+};
+} // end anonymous namespace
+
+/// Walk \p D (a target region's captured decl or a declare target function)
+/// and implicitly mark every function it uses as declare target. Does nothing
+/// unless LangOptions::OpenMPImplicitDeclareTarget is set.
+static void ImplicitDeclareTargetCheck(Sema &SemaRef, Decl *D) {
+  // Implicit marking is opt-in via the implicit declare target extension.
+  if (!SemaRef.getLangOpts().OpenMPImplicitDeclareTarget)
+    return;
+  // Visit the structured block so that every function it calls (directly or
+  // transitively, since marked callees are traversed as well) receives an
+  // implicit OMPDeclareTargetDeclAttr.
+  ImplicitDeviceFunctionChecker FunctionCallChecker(SemaRef);
+  FunctionCallChecker.TraverseDecl(D);
+}
+
+/// If \p D is a function whose body is available and some redeclaration of it
+/// carries OMPDeclareTargetDeclAttr, mark \p D itself as declare target and
+/// then implicitly mark everything its body uses.
+void Sema::checkDeclImplicitlyUsedOpenMPTargetContext(Decl *D) {
+  auto *FD = dyn_cast_or_null<FunctionDecl>(D);
+  if (!FD || FD->isInvalidDecl() || !FD->hasBody())
+    return;
+  for (FunctionDecl *RI : FD->redecls()) {
+    if (!RI->hasAttr<OMPDeclareTargetDeclAttr>())
+      continue;
+    // redecls() includes FD itself; never attach a duplicate attribute.
+    if (!FD->hasAttr<OMPDeclareTargetDeclAttr>()) {
+      Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit(
+          Context, OMPDeclareTargetDeclAttr::MT_To);
+      FD->addAttr(A);
+      if (ASTMutationListener *ML = Context.getASTMutationListener())
+        ML->DeclarationMarkedOpenMPDeclareTarget(FD, A);
+    }
+    // Propagate declare target to the functions used by the body.
+    ImplicitDeclareTargetCheck(*this, FD);
+    return;
+  }
+}
+
+bool ImplicitDeviceFunctionChecker::TraverseLambdaCapture(
+    LambdaExpr *LE, const LambdaCapture *C, Expr *Init) {
+  // Mark the closure type (by default invoked once per explicit capture).
+  CXXRecordDecl *Class = LE->getLambdaClass();
+  if (Class && !Class->hasAttr<OMPDeclareTargetDeclAttr>())
+    Class->addAttr(OMPDeclareTargetDeclAttr::CreateImplicit(
+        SemaRef.Context, OMPDeclareTargetDeclAttr::MT_To));
+  // Delegate so capture initializers are still visited; the lambda body is
+  // traversed once by the default LambdaExpr traversal, not per capture.
+  return RecursiveASTVisitor<ImplicitDeviceFunctionChecker>::
+      TraverseLambdaCapture(LE, C, Init);
+}
+
+bool ImplicitDeviceFunctionChecker::VisitFunctionDecl(FunctionDecl *F) {
+  assert(F && "expected a function");
+  // Already marked: nothing to do. This also stops recursive call chains.
+  if (F->hasAttr<OMPDeclareTargetDeclAttr>())
+    return true;
+  F->addAttr(OMPDeclareTargetDeclAttr::CreateImplicit(
+      SemaRef.Context, OMPDeclareTargetDeclAttr::MT_To));
+  TraverseDecl(F); // Mark whatever F itself uses.
+  return true;
+}
+
+bool ImplicitDeviceFunctionChecker::VisitCallExpr(CallExpr *Call) {
+  // Calls without a statically known callee (e.g. through a function
+  // pointer) have nothing to mark here, so they are simply skipped.
+  FunctionDecl *Callee = Call->getDirectCallee();
+  return Callee ? VisitFunctionDecl(Callee) : true;
+}
+
+bool ImplicitDeviceFunctionChecker::VisitCXXConstructExpr(CXXConstructExpr *E) {
+  // An object constructed on the device may also be destroyed there, so the
+  // destructor of the constructed class (or of the array element class) is
+  // marked as declare target as well.
+  QualType Ty = SemaRef.Context.getBaseElementType(E->getType());
+  // getAs<RecordType>() may yield null for some class types (presumably e.g.
+  // injected-class-name types in templates); getAsCXXRecordDecl() plus a null
+  // check avoids dereferencing a missing record.
+  if (CXXRecordDecl *RD = Ty->getAsCXXRecordDecl())
+    if (CXXDestructorDecl *Destructor = RD->getDestructor())
+      VisitCXXDestructorDecl(Destructor);
+
+  // Finally mark (and walk) the constructor that is actually invoked.
+  return VisitFunctionDecl(E->getConstructor());
+}
+
+bool ImplicitDeviceFunctionChecker::VisitCXXDestructorDecl(
+    CXXDestructorDecl *D) {
+  assert(D && "expected a destructor");
+  return VisitFunctionDecl(D); // Same marking/traversal as any function.
+}
+
 void Sema::InitDataSharingAttributesStack() {
   VarDataSharingAttributesStack = new DSAStackTy(*this);
 }
@@ -1304,12 +1423,12 @@
   // If we are attempting to capture a global variable in a directive with
   // 'target' we return true so that this global is also mapped to the device.
   //
-  // FIXME: If the declaration is enclosed in a 'declare target' directive,
-  // then it should not be captured. Therefore, an extra check has to be
-  // inserted here once support for 'declare target' is added.
+  // If the variable is enclosed in a declare target directive, that is not
+  // required.
   //
   auto *VD = dyn_cast<VarDecl>(D);
-  if (VD && !VD->hasLocalStorage()) {
+  if (VD && !VD->hasLocalStorage() &&
+        !VD->hasAttr<OMPDeclareTargetDeclAttr>()) {
     if (isOpenMPTargetExecutionDirective(DSAStack->getCurrentDirective()) &&
         !DSAStack->isClauseParsingMode())
       return VD;
@@ -6270,6 +6389,8 @@
 
   getCurFunction()->setHasBranchProtectedScope();
 
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt);
 }
 
@@ -6290,6 +6411,8 @@
 
   getCurFunction()->setHasBranchProtectedScope();
 
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetParallelDirective::Create(Context, StartLoc, EndLoc, Clauses,
                                             AStmt);
 }
@@ -6334,6 +6457,9 @@
   }
 
   getCurFunction()->setHasBranchProtectedScope();
+
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetParallelForDirective::Create(Context, StartLoc, EndLoc,
                                                NestedLoopCount, Clauses, AStmt,
                                                B, DSAStack->isCancelRegion());
@@ -6778,6 +6904,9 @@
     return StmtError();
 
   getCurFunction()->setHasBranchProtectedScope();
+
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetParallelForSimdDirective::Create(
       Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
 }
@@ -6825,6 +6954,9 @@
     return StmtError();
 
   getCurFunction()->setHasBranchProtectedScope();
+
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetSimdDirective::Create(Context, StartLoc, EndLoc,
                                         NestedLoopCount, Clauses, AStmt, B);
 }
@@ -6906,6 +7038,9 @@
     return StmtError();
 
   getCurFunction()->setHasBranchProtectedScope();
+
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTeamsDistributeSimdDirective::Create(
       Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
 }
@@ -7020,6 +7155,8 @@
 
   getCurFunction()->setHasBranchProtectedScope();
 
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetTeamsDirective::Create(Context, StartLoc, EndLoc, Clauses,
                                          AStmt);
 }
@@ -7054,6 +7191,9 @@
          "omp target teams distribute loop exprs were not built");
 
   getCurFunction()->setHasBranchProtectedScope();
+
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetTeamsDistributeDirective::Create(
       Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
 }
@@ -7099,6 +7239,9 @@
   }
 
   getCurFunction()->setHasBranchProtectedScope();
+
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetTeamsDistributeParallelForDirective::Create(
       Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
 }
@@ -7145,6 +7288,9 @@
   }
 
   getCurFunction()->setHasBranchProtectedScope();
+
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetTeamsDistributeParallelForSimdDirective::Create(
       Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
 }
@@ -7178,6 +7324,9 @@
          "omp target teams distribute simd loop exprs were not built");
 
   getCurFunction()->setHasBranchProtectedScope();
+
+  ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl());
+
   return OMPTargetTeamsDistributeSimdDirective::Create(
       Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
 }
@@ -12041,17 +12190,18 @@
     // target region (it can be e.g. a lambda) that is legal and we do not need
     // to do anything else.
     if (LD == D) {
-      Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit(
-          SemaRef.Context, OMPDeclareTargetDeclAttr::MT_To);
-      D->addAttr(A);
-      if (ASTMutationListener *ML = SemaRef.Context.getASTMutationListener())
-        ML->DeclarationMarkedOpenMPDeclareTarget(D, A);
+      if (!SemaRef.getLangOpts().OpenMPImplicitDeclareTarget)
+        if (!D->hasAttr<OMPDeclareTargetDeclAttr>())
+          SemaRef.Diag(LD->getLocation(), diag::warn_omp_not_in_target_context);
+
       return;
     }
   }
   if (!LD)
     LD = D;
-  if (LD && !LD->hasAttr<OMPDeclareTargetDeclAttr>() &&
+  // The parameters of a function are considered 'declare target' declarations
+  // if the function itself is 'declare target'.
+  if (LD && !LD->hasAttr<OMPDeclareTargetDeclAttr>() && !isa<ParmVarDecl>(LD) &&
       (isa<VarDecl>(LD) || isa<FunctionDecl>(LD))) {
     // Outlined declaration is not declared target.
     if (LD->isOutOfLine()) {
@@ -12120,6 +12270,16 @@
       return;
     }
   }
+  if (TemplateDecl *TD = dyn_cast<TemplateDecl>(D)) {
+    // Mark template declarations as declare target so that they can propagate
+    // that information to their instances.
+    Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit(
+        Context, OMPDeclareTargetDeclAttr::MT_To);
+    TD->addAttr(A);
+    if (ASTMutationListener *ML = Context.getASTMutationListener())
+      ML->DeclarationMarkedOpenMPDeclareTarget(TD, A);
+    return;
+  }
   if (!E) {
     // Checking declaration inside declare target region.
     if (!D->hasAttr<OMPDeclareTargetDeclAttr>() &&
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -6746,7 +6746,8 @@
       case SC_Register:
         // Local Named register
         if (!Context.getTargetInfo().isValidGCCRegisterName(Label) &&
-            DeclAttrsMatchCUDAMode(getLangOpts(), getCurFunctionDecl()))
+            DeclAttrsMatchOffloadMode(getLangOpts(), getCurFunctionDecl(),
+                                      IsInOpenMPDeclareTargetContext))
           Diag(E->getExprLoc(), diag::err_asm_unknown_register_name) << Label;
         break;
       case SC_Static:
@@ -6756,7 +6757,8 @@
       }
     } else if (SC == SC_Register) {
       // Global Named register
-      if (DeclAttrsMatchCUDAMode(getLangOpts(), NewVD)) {
+      if (DeclAttrsMatchOffloadMode(getLangOpts(), NewVD,
+                                    IsInOpenMPDeclareTargetContext)) {
         const auto &TI = Context.getTargetInfo();
         bool HasSizeMismatch;
 
@@ -12656,6 +12658,11 @@
     DiscardCleanupsInEvaluationContext();
   }
 
+  // In case of OpenMPImplicitDeclareTarget, semantically parsed function body
+  // is visited to mark inner callexpr with OMPDeclareTargetDeclAttr attribute.
+  if (getLangOpts().OpenMP && getLangOpts().OpenMPImplicitDeclareTarget)
+    checkDeclImplicitlyUsedOpenMPTargetContext(dcl);
+
   return dcl;
 }
 
Index: lib/Parse/ParseOpenMP.cpp
===================================================================
--- lib/Parse/ParseOpenMP.cpp
+++ lib/Parse/ParseOpenMP.cpp
@@ -757,6 +757,7 @@
     if (!Actions.ActOnStartOpenMPDeclareTargetDirective(DTLoc))
       return DeclGroupPtrTy();
 
+    SmallVector<Decl *, 32> Decls;
     DKind = ParseOpenMPDirectiveKind(*this);
     while (DKind != OMPD_end_declare_target && DKind != OMPD_declare_target &&
            Tok.isNot(tok::eof) && Tok.isNot(tok::r_brace)) {
@@ -780,6 +781,12 @@
         else
           TPA.Commit();
       }
+
+      // Save the declarations so that we can create the declare target group
+      // later on.
+      if (Ptr)
+        for (auto *V : Ptr.get())
+          Decls.push_back(V);
     }
 
     if (DKind == OMPD_end_declare_target) {
@@ -794,8 +801,17 @@
     } else {
       Diag(Tok, diag::err_expected_end_declare_target);
       Diag(DTLoc, diag::note_matching) << "'#pragma omp declare target'";
+      // We have an error, so we don't have to attempt to generate code for the
+      // declarations.
+      Decls.clear();
     }
     Actions.ActOnFinishOpenMPDeclareTargetDirective();
+
+    // If we have decls generate the group so that code can be generated for it
+    // later on.
+    if (!Decls.empty())
+      return Actions.BuildDeclaratorGroup(Decls);
+
     return DeclGroupPtrTy();
   }
   case OMPD_unknown:
Index: include/clang/Sema/SemaInternal.h
===================================================================
--- include/clang/Sema/SemaInternal.h
+++ include/clang/Sema/SemaInternal.h
@@ -60,6 +60,16 @@
   return isDeviceSideDecl == LangOpts.CUDAIsDevice;
 }
 
+// Helper function to check whether D's attributes match the current
+// offloading mode. When compiling for an OpenMP device, the answer is the
+// caller-supplied InOpenMPDeviceRegion flag; otherwise the CUDA host/device
+// attribute rules (DeclAttrsMatchCUDAMode) decide.
+inline bool DeclAttrsMatchOffloadMode(const LangOptions &LangOpts, Decl *D,
+                                      bool InOpenMPDeviceRegion) {
+  return LangOpts.OpenMPIsDevice ? InOpenMPDeviceRegion
+                                 : DeclAttrsMatchCUDAMode(LangOpts, D);
+}
+
 // Directly mark a variable odr-used. Given a choice, prefer to use 
 // MarkVariableReferenced since it does additional checks and then 
 // calls MarkVarDeclODRUsed.
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -8658,6 +8658,9 @@
   bool isInOpenMPDeclareTargetContext() const {
     return IsInOpenMPDeclareTargetContext;
   }
+  /// Check and mark declarations that are implicitly used inside OpenMP target
+  /// regions.
+  void checkDeclImplicitlyUsedOpenMPTargetContext(Decl *D);
 
   /// Return the number of captured regions created for an OpenMP directive.
   static int getOpenMPCaptureLevels(OpenMPDirectiveKind Kind);
Index: include/clang/Basic/LangOptions.def
===================================================================
--- include/clang/Basic/LangOptions.def
+++ include/clang/Basic/LangOptions.def
@@ -190,6 +190,7 @@
 LANGOPT(CUDA              , 1, 0, "CUDA")
 LANGOPT(OpenMP            , 32, 0, "OpenMP support and version of OpenMP (31, 40 or 45)")
 LANGOPT(OpenMPUseTLS      , 1, 0, "Use TLS for threadprivates or runtime calls")
+LANGOPT(OpenMPImplicitDeclareTarget , 1, 0, "Enable implicit declare target extension - marks automatically declarations and definitions with declare target attribute")
 LANGOPT(OpenMPIsDevice    , 1, 0, "Generate code only for OpenMP target device")
 LANGOPT(RenderScript      , 1, 0, "RenderScript")
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to