hliao created this revision.
hliao added reviewers: tra, rsmith, yaxunl.
Herald added a reviewer: martong.
Herald added a reviewer: shafik.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

- The HIP/CUDA host side needs to use the device kernel symbol name to match the
device-side binaries. Without consistent naming between host- and device-side
compilations, there is a risk that the wrong device binaries are executed.
Consistent naming is usually not an issue until unnamed types are used,
especially lambdas. In this patch, consistent name mangling is addressed for
extended lambdas, i.e. lambdas annotated with `__device__`.
- In the [Itanium C++ ABI][1], the mangling of a lambda is generally
unspecified unless, in certain cases, the ODR requires consistent naming across
TUs. An extended lambda is such a case, as its name may be part of the name of
a device kernel function, e.g., when the extended lambda is used as a template
argument. Thus, we need to force ODR naming for extended lambdas, as they are
referenced in both device- and host-side TUs. Furthermore, if an extended
lambda is nested in other (extended or not) lambdas, those lambdas are required
to follow ODR naming as well. This patch revises the current lambda mangle
numbering to force ODR naming from an extended lambda to all its parent lambdas.
- On the other hand, the aforementioned ODR naming should not change those
lambdas' original linkage, i.e., we cannot replace the original `internal`
with `linkonce_odr`; otherwise, we may violate the ODR in general. This patch
introduces a new field, `HasKnownInternalLinkage`, in the lambda data to
decouple the linkage calculation from whether a mangling number is assigned.

[1]: https://itanium-cxx-abi.github.io/cxx-abi/abi.html


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D68818

Files:
  clang/include/clang/AST/DeclCXX.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/ASTImporter.cpp
  clang/lib/AST/Decl.cpp
  clang/lib/AST/DeclCXX.cpp
  clang/lib/Sema/SemaLambda.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReaderDecl.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/CodeGenCUDA/unnamed-types.cu

Index: clang/test/CodeGenCUDA/unnamed-types.cu
===================================================================
--- /dev/null
+++ clang/test/CodeGenCUDA/unnamed-types.cu
@@ -0,0 +1,39 @@
+// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-linux-gnu -emit-llvm %s -o - | FileCheck %s --check-prefix=HOST
+// RUN: %clang_cc1 -std=c++11 -x hip -triple amdgcn-amd-amdhsa -fcuda-is-device -emit-llvm %s -o - | FileCheck %s --check-prefix=DEVICE
+
+#include "Inputs/cuda.h"
+
+// HOST: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1
+
+__device__ float d0(float x) {
+  return [](float x) { return x + 2.f; }(x);
+}
+
+__device__ float d1(float x) {
+  return [](float x) { return x * 2.f; }(x);
+}
+
+// DEVICE: amdgpu_kernel void @_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_(
+template <typename F>
+__global__ void k0(float *p, F f) {
+  p[0] = f(p[0]) + d0(p[1]) + d1(p[2]);
+}
+
+void f0(float *p) {
+  [](float *p) {
+    *p = 1.f;
+  }(p);
+}
+
+// The inner/outer lambdas are required to be mangled following ODR but their
+// linkages are still required to keep the original `internal` linkage.
+
+// HOST: define internal void @_ZZ2f1PfENKUlS_E_clES_(
+// DEVICE: define internal float @_ZZZ2f1PfENKUlS_E_clES_ENKUlfE_clEf(
+void f1(float *p) {
+  [](float *p) {
+    k0<<<1,1>>>(p, [] __device__ (float x) { return x + 1.f; });
+  }(p);
+}
+// HOST: @__hip_register_globals
+// HOST: __hipRegisterFunction{{.*}}@_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_{{.*}}@0
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6222,6 +6222,7 @@
     Record->push_back(Lambda.CaptureDefault);
     Record->push_back(Lambda.NumCaptures);
     Record->push_back(Lambda.NumExplicitCaptures);
+    Record->push_back(Lambda.HasKnownInternalLinkage);
     Record->push_back(Lambda.ManglingNumber);
     AddDeclRef(D->getLambdaContextDecl());
     AddTypeSourceInfo(Lambda.MethodTyInfo);
Index: clang/lib/Serialization/ASTReaderDecl.cpp
===================================================================
--- clang/lib/Serialization/ASTReaderDecl.cpp
+++ clang/lib/Serialization/ASTReaderDecl.cpp
@@ -1732,6 +1732,7 @@
     Lambda.CaptureDefault = Record.readInt();
     Lambda.NumCaptures = Record.readInt();
     Lambda.NumExplicitCaptures = Record.readInt();
+    Lambda.HasKnownInternalLinkage = Record.readInt();
     Lambda.ManglingNumber = Record.readInt();
     Lambda.ContextDecl = ReadDeclID();
     Lambda.Captures = (Capture *)Reader.getContext().Allocate(
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -11373,17 +11373,18 @@
                                         E->getCaptureDefault());
   getDerived().transformedLocalDecl(OldClass, {Class});
 
-  Optional<std::pair<unsigned, Decl*>> Mangling;
+  Optional<std::tuple<unsigned, bool, Decl *>> Mangling;
   if (getDerived().ReplacingOriginal())
-    Mangling = std::make_pair(OldClass->getLambdaManglingNumber(),
-                              OldClass->getLambdaContextDecl());
+    Mangling = std::make_tuple(OldClass->getLambdaManglingNumber(),
+                               OldClass->hasKnownInternalLinkage(),
+                               OldClass->getLambdaContextDecl());
 
   // Build the call operator.
   CXXMethodDecl *NewCallOperator = getSema().startLambdaDefinition(
       Class, E->getIntroducerRange(), NewCallOpTSI,
       E->getCallOperator()->getEndLoc(),
       NewCallOpTSI->getTypeLoc().castAs<FunctionProtoTypeLoc>().getParams(),
-      E->getCallOperator()->getConstexprKind(), Mangling);
+      E->getCallOperator()->getConstexprKind());
 
   LSI->CallOperator = NewCallOperator;
 
@@ -11403,6 +11404,9 @@
   getDerived().transformAttrs(E->getCallOperator(), NewCallOperator);
   getDerived().transformedLocalDecl(E->getCallOperator(), {NewCallOperator});
 
+  // Number the lambda for linkage purposes if necessary.
+  getSema().handleLambdaNumbering(Class, NewCallOperator, Mangling);
+
   // Introduce the context of the call operator.
   Sema::ContextRAII SavedContext(getSema(), NewCallOperator,
                                  /*NewThisContext*/false);
Index: clang/lib/Sema/SemaLambda.cpp
===================================================================
--- clang/lib/Sema/SemaLambda.cpp
+++ clang/lib/Sema/SemaLambda.cpp
@@ -335,7 +335,7 @@
   case StaticDataMember:
     //  -- the initializers of nonspecialized static members of template classes
     if (!IsInNonspecializedTemplate)
-      return std::make_tuple(nullptr, nullptr);
+      return std::make_tuple(nullptr, ManglingContextDecl);
     // Fall through to get the current context.
     LLVM_FALLTHROUGH;
 
@@ -356,14 +356,15 @@
   llvm_unreachable("unexpected context");
 }
 
-CXXMethodDecl *Sema::startLambdaDefinition(
-    CXXRecordDecl *Class, SourceRange IntroducerRange,
-    TypeSourceInfo *MethodTypeInfo, SourceLocation EndLoc,
-    ArrayRef<ParmVarDecl *> Params, ConstexprSpecKind ConstexprKind,
-    Optional<std::pair<unsigned, Decl *>> Mangling) {
+CXXMethodDecl *Sema::startLambdaDefinition(CXXRecordDecl *Class,
+                                           SourceRange IntroducerRange,
+                                           TypeSourceInfo *MethodTypeInfo,
+                                           SourceLocation EndLoc,
+                                           ArrayRef<ParmVarDecl *> Params,
+                                           ConstexprSpecKind ConstexprKind) {
   QualType MethodType = MethodTypeInfo->getType();
   TemplateParameterList *TemplateParams =
-            getGenericLambdaTemplateParameterList(getCurLambda(), *this);
+      getGenericLambdaTemplateParameterList(getCurLambda(), *this);
   // If a lambda appears in a dependent context or is a generic lambda (has
   // template parameters) and has an 'auto' return type, deduce it to a
   // dependent type.
@@ -425,20 +426,72 @@
       P->setOwningFunction(Method);
   }
 
+  return Method;
+}
+
+void Sema::handleLambdaNumbering(
+    CXXRecordDecl *Class, CXXMethodDecl *Method,
+    Optional<std::tuple<unsigned, bool, Decl *>> Mangling) {
   if (Mangling) {
-    Class->setLambdaMangling(Mangling->first, Mangling->second);
-  } else {
-    MangleNumberingContext *MCtx;
+    unsigned ManglingNumber;
+    bool HasKnownInternalLinkage;
     Decl *ManglingContextDecl;
-    std::tie(MCtx, ManglingContextDecl) =
-        getCurrentMangleNumberContext(Class->getDeclContext());
-    if (MCtx) {
-      unsigned ManglingNumber = MCtx->getManglingNumber(Method);
-      Class->setLambdaMangling(ManglingNumber, ManglingContextDecl);
-    }
+    std::tie(ManglingNumber, HasKnownInternalLinkage, ManglingContextDecl) =
+        Mangling.getValue();
+    Class->setLambdaManglingNumber(ManglingNumber, HasKnownInternalLinkage);
+    Class->setLambdaContextDecl(ManglingContextDecl);
+    return;
   }
 
-  return Method;
+  auto getMangleNumberingContext =
+      [this](CXXRecordDecl *Class) -> MangleNumberingContext * {
+    // Get mangle numbering context if there's any extra decl context.
+    if (auto ManglingContextDecl = Class->getLambdaContextDecl())
+      return &Context.getManglingNumberContext(
+          ASTContext::NeedExtraManglingDecl, ManglingContextDecl);
+    // Otherwise, from that lambda's decl context.
+    auto DC = Class->getDeclContext();
+    while (auto *CD = dyn_cast<CapturedDecl>(DC))
+      DC = CD->getParent();
+    return &Context.getManglingNumberContext(DC);
+  };
+
+  MangleNumberingContext *MCtx;
+  Decl *ManglingContextDecl;
+  std::tie(MCtx, ManglingContextDecl) =
+      getCurrentMangleNumberContext(Class->getDeclContext());
+  // Remember the extra mangle context decl even this lambda is not numbered so
+  // far. A lambda may need numbering later once an inner nested lambda is an
+  // extended one.
+  Class->setLambdaContextDecl(ManglingContextDecl);
+  if (MCtx) {
+    unsigned ManglingNumber = MCtx->getManglingNumber(Method);
+    Class->setLambdaManglingNumber(ManglingNumber);
+  }
+  // For an extended lambda, i.e. a lambda annotated with `__device__`, it must
+  // have a mangling number so that it could be named consistently between
+  // host- and device-side compilations following ODR rule. Furthermore, if
+  // that extended lambda is nested within other lambdas, all its parent
+  // lambdas must have mangling number as well in order to have a consistently
+  // mangled name. But, the mangler needs informing those re-numbered lambdas
+  // still have `internal` linkage.
+  if (getLangOpts().CUDA && Method->hasAttr<CUDADeviceAttr>()) {
+    for (DeclContext *DC = Method; DC && !DC->isTranslationUnit();
+         DC = getLambdaAwareParentOfDeclContext(DC)) {
+      // Skip if it's not a lambda.
+      if (!isLambdaCallOperator(DC))
+        continue;
+      Method = cast<CXXMethodDecl>(DC);
+      Class = Method->getParent();
+      // Skip if it already has mangling number.
+      if (Class->getLambdaManglingNumber())
+        continue;
+      // Allocate/get mangle numbering context.
+      MCtx = getMangleNumberingContext(Class);
+      Class->setLambdaManglingNumber(MCtx->getManglingNumber(Method),
+                                     /*HasKnownInternalLinkage=*/true);
+    }
+  }
 }
 
 void Sema::buildLambdaScope(LambdaScopeInfo *LSI,
@@ -951,6 +1004,9 @@
   if (getLangOpts().CUDA)
     CUDASetLambdaAttrs(Method);
 
+  // Number the lambda for linkage purposes if necessary.
+  handleLambdaNumbering(Class, Method);
+
   // Introduce the function call operator as the current declaration context.
   PushDeclContext(CurScope, Method);
 
Index: clang/lib/AST/DeclCXX.cpp
===================================================================
--- clang/lib/AST/DeclCXX.cpp
+++ clang/lib/AST/DeclCXX.cpp
@@ -1487,6 +1487,9 @@
 
 Decl *CXXRecordDecl::getLambdaContextDecl() const {
   assert(isLambda() && "Not a lambda closure type!");
+  // Skip if this lambda is not numbered.
+  if (!getLambdaData().ManglingNumber)
+    return nullptr;
   ExternalASTSource *Source = getParentASTContext().getExternalSource();
   return getLambdaData().ContextDecl.get(Source);
 }
Index: clang/lib/AST/Decl.cpp
===================================================================
--- clang/lib/AST/Decl.cpp
+++ clang/lib/AST/Decl.cpp
@@ -1385,7 +1385,8 @@
     case Decl::CXXRecord: {
       const auto *Record = cast<CXXRecordDecl>(D);
       if (Record->isLambda()) {
-        if (!Record->getLambdaManglingNumber()) {
+        if (Record->hasKnownInternalLinkage() ||
+            !Record->getLambdaManglingNumber()) {
           // This lambda has no mangling number, so it's internal.
           return getInternalLinkageFor(D);
         }
@@ -1402,7 +1403,8 @@
         //  };
         const CXXRecordDecl *OuterMostLambda =
             getOutermostEnclosingLambda(Record);
-        if (!OuterMostLambda->getLambdaManglingNumber())
+        if (OuterMostLambda->hasKnownInternalLinkage() ||
+            !OuterMostLambda->getLambdaManglingNumber())
           return getInternalLinkageFor(D);
 
         return getLVForClosure(
Index: clang/lib/AST/ASTImporter.cpp
===================================================================
--- clang/lib/AST/ASTImporter.cpp
+++ clang/lib/AST/ASTImporter.cpp
@@ -2755,7 +2755,9 @@
       ExpectedDecl CDeclOrErr = import(DCXX->getLambdaContextDecl());
       if (!CDeclOrErr)
         return CDeclOrErr.takeError();
-      D2CXX->setLambdaMangling(DCXX->getLambdaManglingNumber(), *CDeclOrErr);
+      D2CXX->setLambdaManglingNumber(DCXX->getLambdaManglingNumber(),
+                                     DCXX->hasKnownInternalLinkage());
+      D2CXX->setLambdaContextDecl(*CDeclOrErr);
     } else if (DCXX->isInjectedClassName()) {
       // We have to be careful to do a similar dance to the one in
       // Sema::ActOnStartCXXMemberDeclarations
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -5853,12 +5853,17 @@
                                          LambdaCaptureDefault CaptureDefault);
 
   /// Start the definition of a lambda expression.
-  CXXMethodDecl *
-  startLambdaDefinition(CXXRecordDecl *Class, SourceRange IntroducerRange,
-                        TypeSourceInfo *MethodType, SourceLocation EndLoc,
-                        ArrayRef<ParmVarDecl *> Params,
-                        ConstexprSpecKind ConstexprKind,
-                        Optional<std::pair<unsigned, Decl *>> Mangling = None);
+  CXXMethodDecl *startLambdaDefinition(CXXRecordDecl *Class,
+                                       SourceRange IntroducerRange,
+                                       TypeSourceInfo *MethodType,
+                                       SourceLocation EndLoc,
+                                       ArrayRef<ParmVarDecl *> Params,
+                                       ConstexprSpecKind ConstexprKind);
+
+  /// Number lambda for linkage purposes if necessary.
+  void handleLambdaNumbering(
+      CXXRecordDecl *Class, CXXMethodDecl *Method,
+      Optional<std::tuple<unsigned, bool, Decl *>> Mangling = None);
 
   /// Endow the lambda scope info with the relevant properties.
   void buildLambdaScope(sema::LambdaScopeInfo *LSI,
Index: clang/include/clang/AST/DeclCXX.h
===================================================================
--- clang/include/clang/AST/DeclCXX.h
+++ clang/include/clang/AST/DeclCXX.h
@@ -584,7 +584,10 @@
     unsigned NumCaptures : 15;
 
     /// The number of explicit captures in this lambda.
-    unsigned NumExplicitCaptures : 13;
+    unsigned NumExplicitCaptures : 12;
+
+    /// Has known `internal` linkage.
+    unsigned HasKnownInternalLinkage : 1;
 
     /// The number used to indicate this lambda expression for name
     /// mangling in the Itanium C++ ABI.
@@ -603,12 +606,12 @@
     /// The type of the call method.
     TypeSourceInfo *MethodTyInfo;
 
-    LambdaDefinitionData(CXXRecordDecl *D, TypeSourceInfo *Info,
-                         bool Dependent, bool IsGeneric,
-                         LambdaCaptureDefault CaptureDefault)
-      : DefinitionData(D), Dependent(Dependent), IsGenericLambda(IsGeneric),
-        CaptureDefault(CaptureDefault), NumCaptures(0), NumExplicitCaptures(0),
-        MethodTyInfo(Info) {
+    LambdaDefinitionData(CXXRecordDecl *D, TypeSourceInfo *Info, bool Dependent,
+                         bool IsGeneric, LambdaCaptureDefault CaptureDefault)
+        : DefinitionData(D), Dependent(Dependent), IsGenericLambda(IsGeneric),
+          CaptureDefault(CaptureDefault), NumCaptures(0),
+          NumExplicitCaptures(0), HasKnownInternalLinkage(0),
+          MethodTyInfo(Info) {
       IsLambda = true;
 
       // C++1z [expr.prim.lambda]p4:
@@ -1902,6 +1905,13 @@
     return getLambdaData().ManglingNumber;
   }
 
+  /// The lambda is known to have internal linkage regardless of whether it has
+  /// a name mangling number.
+  bool hasKnownInternalLinkage() const {
+    assert(isLambda() && "Not a lambda closure type!");
+    return getLambdaData().HasKnownInternalLinkage;
+  }
+
   /// Retrieve the declaration that provides additional context for a
   /// lambda, when the normal declaration context is not specific enough.
   ///
@@ -1913,11 +1923,18 @@
   /// the declaration context suffices.
   Decl *getLambdaContextDecl() const;
 
+  void setLambdaContextDecl(Decl *ContextDecl) {
+    assert(isLambda() && "Not a lambda closure type!");
+    getLambdaData().ContextDecl = ContextDecl;
+  }
+
   /// Set the mangling number and context declaration for a lambda
   /// class.
-  void setLambdaMangling(unsigned ManglingNumber, Decl *ContextDecl) {
+  void setLambdaManglingNumber(unsigned ManglingNumber,
+                               bool HasKnownInternalLinkage = false) {
+    assert(isLambda() && "Not a lambda closure type!");
     getLambdaData().ManglingNumber = ManglingNumber;
-    getLambdaData().ContextDecl = ContextDecl;
+    getLambdaData().HasKnownInternalLinkage = HasKnownInternalLinkage;
   }
 
   /// Returns the inheritance model used for this record.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to