v.g.vassilev updated this revision to Diff 137902.
v.g.vassilev added a comment.

Add sanity check.

We assert that the ODR hash of the template arguments is the same as the
hash computed from the template specialization that was actually found.

This way we managed to catch a few collisions in the ODRHash logic.


https://reviews.llvm.org/D41416

Files:
  include/clang/AST/DeclTemplate.h
  lib/AST/DeclTemplate.cpp
  lib/AST/ODRHash.cpp
  lib/Serialization/ASTReader.cpp
  lib/Serialization/ASTReaderDecl.cpp
  lib/Serialization/ASTWriter.cpp
  lib/Serialization/ASTWriterDecl.cpp
  test/Modules/cxx-templates.cpp

Index: test/Modules/cxx-templates.cpp
===================================================================
--- test/Modules/cxx-templates.cpp
+++ test/Modules/cxx-templates.cpp
@@ -249,24 +249,24 @@
 
 // CHECK-DUMP:      ClassTemplateDecl {{.*}} <{{.*[/\\]}}cxx-templates-common.h:1:1, {{.*}}>  col:{{.*}} in cxx_templates_common SomeTemplate
 // CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} prev {{.*}} SomeTemplate
-// CHECK-DUMP-NEXT:     TemplateArgument type 'char [2]'
+// CHECK-DUMP-NEXT:     TemplateArgument type 'char [1]'
 // CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} SomeTemplate definition
 // CHECK-DUMP-NEXT:     DefinitionData
 // CHECK-DUMP-NEXT:       DefaultConstructor
 // CHECK-DUMP-NEXT:       CopyConstructor
 // CHECK-DUMP-NEXT:       MoveConstructor
 // CHECK-DUMP-NEXT:       CopyAssignment
 // CHECK-DUMP-NEXT:       MoveAssignment
 // CHECK-DUMP-NEXT:       Destructor
-// CHECK-DUMP-NEXT:     TemplateArgument type 'char [2]'
-// CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} prev {{.*}} SomeTemplate
 // CHECK-DUMP-NEXT:     TemplateArgument type 'char [1]'
+// CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} prev {{.*}} SomeTemplate
+// CHECK-DUMP-NEXT:     TemplateArgument type 'char [2]'
 // CHECK-DUMP:        ClassTemplateSpecializationDecl {{.*}} SomeTemplate definition
 // CHECK-DUMP-NEXT:     DefinitionData
 // CHECK-DUMP-NEXT:       DefaultConstructor
 // CHECK-DUMP-NEXT:       CopyConstructor
 // CHECK-DUMP-NEXT:       MoveConstructor
 // CHECK-DUMP-NEXT:       CopyAssignment
 // CHECK-DUMP-NEXT:       MoveAssignment
 // CHECK-DUMP-NEXT:       Destructor
-// CHECK-DUMP-NEXT:     TemplateArgument type 'char [1]'
+// CHECK-DUMP-NEXT:     TemplateArgument type 'char [2]'
Index: lib/Serialization/ASTWriterDecl.cpp
===================================================================
--- lib/Serialization/ASTWriterDecl.cpp
+++ lib/Serialization/ASTWriterDecl.cpp
@@ -162,22 +162,61 @@
       Record.AddSourceLocation(typeParams->getRAngleLoc());
     }
 
-    /// Add to the record the first declaration from each module file that
-    /// provides a declaration of D. The intent is to provide a sufficient
-    /// set such that reloading this set will load all current redeclarations.
-    void AddFirstDeclFromEachModule(const Decl *D, bool IncludeLocal) {
-      llvm::MapVector<ModuleFile*, const Decl*> Firsts;
+    /// Collect the first declaration from each module file that provides a
+    /// declaration of D.
+    void CollectFirstDeclFromEachModule(const Decl *D, bool IncludeLocal,
+                            llvm::MapVector<ModuleFile*, const Decl*> &Firsts) {
+
       // FIXME: We can skip entries that we know are implied by others.
       for (const Decl *R = D->getMostRecentDecl(); R; R = R->getPreviousDecl()) {
         if (R->isFromASTFile())
           Firsts[Writer.Chain->getOwningModuleFile(R)] = R;
         else if (IncludeLocal)
           Firsts[nullptr] = R;
       }
+    }
+
+    /// Add to the record the first declaration from each module file that
+    /// provides a declaration of D. The intent is to provide a sufficient
+    /// set such that reloading this set will load all current redeclarations.
+    void AddFirstDeclFromEachModule(const Decl *D, bool IncludeLocal) {
+      llvm::MapVector<ModuleFile*, const Decl*> Firsts;
+      CollectFirstDeclFromEachModule(D, IncludeLocal, Firsts);
+
       for (const auto &F : Firsts)
         Record.AddDeclRef(F.second);
     }
 
+    /// Add to the record the first template specialization from each module
+    /// file that provides a declaration of D. We store the DeclId and an
+    /// ODRHash of the template arguments of D which should provide enough
+    /// information to load D only if the template instantiator needs it.
+    void AddFirstSpecializationDeclFromEachModule(const Decl *D,
+                                                  bool IncludeLocal) {
+      assert(isa<ClassTemplateSpecializationDecl>(D) ||
+             isa<VarTemplateSpecializationDecl>(D) || isa<FunctionDecl>(D) &&
+             "Must not be called with other decls");
+      llvm::MapVector<ModuleFile*, const Decl*> Firsts;
+      CollectFirstDeclFromEachModule(D, IncludeLocal, Firsts);
+
+      for (const auto &F : Firsts) {
+        Record.AddDeclRef(F.second);
+        ArrayRef<TemplateArgument> Args;
+        if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
+          Args = CTSD->getTemplateArgs().asArray();
+        else if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D))
+          Args = VTSD->getTemplateArgs().asArray();
+        else if (auto *FD = dyn_cast<FunctionDecl>(D))
+          Args = FD->getTemplateSpecializationArgs()->asArray();
+        assert(Args.size());
+        Record.push_back(TemplateArgumentList::ComputeODRHash(Args));
+        bool IsPartialSpecialization
+          = isa<ClassTemplatePartialSpecializationDecl>(D) ||
+          isa<VarTemplatePartialSpecializationDecl>(D);
+        Record.push_back(IsPartialSpecialization);
+      }
+    }
+
     /// Get the specialization decl from an entry in the specialization list.
     template <typename EntryType>
     typename RedeclarableTemplateDecl::SpecEntryTraits<EntryType>::DeclType *
@@ -190,7 +229,8 @@
     decltype(T::PartialSpecializations) &getPartialSpecializations(T *Common) {
       return Common->PartialSpecializations;
     }
-    ArrayRef<Decl> getPartialSpecializations(FunctionTemplateDecl::Common *) {
+    MutableArrayRef<FunctionTemplateSpecializationInfo>
+    getPartialSpecializations(FunctionTemplateDecl::Common *) {
       return None;
     }
 
@@ -207,9 +247,11 @@
         assert(!Common->LazySpecializations);
       }
 
-      ArrayRef<DeclID> LazySpecializations;
+      using LazySpecializationInfo
+        = RedeclarableTemplateDecl::LazySpecializationInfo;
+      ArrayRef<LazySpecializationInfo> LazySpecializations;
       if (auto *LS = Common->LazySpecializations)
-        LazySpecializations = llvm::makeArrayRef(LS + 1, LS[0]);
+        LazySpecializations = llvm::makeArrayRef(LS + 1, LS[0].DeclID);
 
       // Add a slot to the record for the number of specializations.
       unsigned I = Record.size();
@@ -225,12 +267,20 @@
 
       for (auto *D : Specs) {
         assert(D->isCanonicalDecl() && "non-canonical decl in set");
-        AddFirstDeclFromEachModule(D, /*IncludeLocal*/true);
+        AddFirstSpecializationDeclFromEachModule(D, /*IncludeLocal*/true);
+      }
+      for (auto &SpecInfo : LazySpecializations) {
+        Record.push_back(SpecInfo.DeclID);
+        Record.push_back(SpecInfo.ODRHash);
+        Record.push_back(SpecInfo.IsPartial);
       }
-      Record.append(LazySpecializations.begin(), LazySpecializations.end());
 
-      // Update the size entry we added earlier.
-      Record[I] = Record.size() - I - 1;
+      // Update the size entry we added earlier. We linearized the
+      // LazySpecializationInfo members and we need to adjust the size as we
+      // will read them always together.
+      assert ((Record.size() - I - 1) % 3 == 0
+              && "Must be divisible by LazySpecializationInfo count!");
+      Record[I] = (Record.size() - I - 1) / 3;
     }
 
     /// Ensure that this template specialization is associated with the specified
Index: lib/Serialization/ASTWriter.cpp
===================================================================
--- lib/Serialization/ASTWriter.cpp
+++ lib/Serialization/ASTWriter.cpp
@@ -5126,12 +5126,29 @@
 
       switch (Kind) {
       case UPD_CXX_ADDED_IMPLICIT_MEMBER:
-      case UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION:
       case UPD_CXX_ADDED_ANONYMOUS_NAMESPACE:
         assert(Update.getDecl() && "no decl to add?");
         Record.push_back(GetDeclRef(Update.getDecl()));
         break;
-
+      case UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION: {
+        const Decl *Spec = Update.getDecl();
+        assert(Spec && "no decl to add?");
+        Record.push_back(GetDeclRef(Spec));
+        ArrayRef<TemplateArgument> Args;
+        if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(Spec))
+          Args = CTSD->getTemplateArgs().asArray();
+        else if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(Spec))
+          Args = VTSD->getTemplateArgs().asArray();
+        else if (auto *FD = dyn_cast<FunctionDecl>(Spec))
+          Args = FD->getTemplateSpecializationArgs()->asArray();
+        assert(Args.size());
+        Record.push_back(TemplateArgumentList::ComputeODRHash(Args));
+        bool IsPartialSpecialization
+          = isa<ClassTemplatePartialSpecializationDecl>(Spec) ||
+          isa<VarTemplatePartialSpecializationDecl>(Spec);
+        Record.push_back(IsPartialSpecialization);
+        break;
+      }
       case UPD_CXX_ADDED_FUNCTION_DEFINITION:
         break;
 
Index: lib/Serialization/ASTReaderDecl.cpp
===================================================================
--- lib/Serialization/ASTReaderDecl.cpp
+++ lib/Serialization/ASTReaderDecl.cpp
@@ -40,6 +40,8 @@
     const DeclID ThisDeclID;
     const SourceLocation ThisDeclLoc;
     typedef ASTReader::RecordData RecordData;
+    using LazySpecializationInfo
+      = RedeclarableTemplateDecl::LazySpecializationInfo;
     TypeID TypeIDForTypeDecl;
     unsigned AnonymousDeclNumber;
     GlobalDeclID NamedDeclForTagDecl;
@@ -85,9 +87,16 @@
       return Record.readString();
     }
 
-    void ReadDeclIDList(SmallVectorImpl<DeclID> &IDs) {
+    LazySpecializationInfo ReadLazySpecializationInfo() {
+      DeclID ID = ReadDeclID();
+      unsigned Hash = Record.readInt();
+      bool IsPartial = Record.readInt();
+      return LazySpecializationInfo(ID, Hash, IsPartial);
+    }
+
+    void ReadDeclIDList(SmallVectorImpl<LazySpecializationInfo> &IDs) {
       for (unsigned I = 0, Size = Record.readInt(); I != Size; ++I)
-        IDs.push_back(ReadDeclID());
+        IDs.push_back(ReadLazySpecializationInfo());
     }
 
     Decl *ReadDecl() {
@@ -221,7 +230,7 @@
 
     template <typename T> static
     void AddLazySpecializations(T *D,
-                                SmallVectorImpl<serialization::DeclID>& IDs) {
+                                SmallVectorImpl<LazySpecializationInfo>& IDs) {
       if (IDs.empty())
         return;
 
@@ -231,12 +240,11 @@
       auto *&LazySpecializations = D->getCommonPtr()->LazySpecializations;
 
       if (auto &Old = LazySpecializations) {
-        IDs.insert(IDs.end(), Old + 1, Old + 1 + Old[0]);
+        IDs.insert(IDs.end(), Old + 1, Old + 1 + Old[0].DeclID);
         std::sort(IDs.begin(), IDs.end());
         IDs.erase(std::unique(IDs.begin(), IDs.end()), IDs.end());
       }
-
-      auto *Result = new (C) serialization::DeclID[1 + IDs.size()];
+      auto *Result = new (C) LazySpecializationInfo[1 + IDs.size()];
       *Result = IDs.size();
       std::copy(IDs.begin(), IDs.end(), Result + 1);
 
@@ -271,7 +279,7 @@
     void ReadFunctionDefinition(FunctionDecl *FD);
     void Visit(Decl *D);
 
-    void UpdateDecl(Decl *D, llvm::SmallVectorImpl<serialization::DeclID>&);
+    void UpdateDecl(Decl *D, llvm::SmallVectorImpl<LazySpecializationInfo>&);
 
     static void setNextObjCCategory(ObjCCategoryDecl *Cat,
                                     ObjCCategoryDecl *Next) {
@@ -2020,7 +2028,7 @@
   if (ThisDeclID == Redecl.getFirstID()) {
     // This ClassTemplateDecl owns a CommonPtr; read it to keep track of all of
     // the specializations.
-    SmallVector<serialization::DeclID, 32> SpecIDs;
+    SmallVector<LazySpecializationInfo, 32> SpecIDs;
     ReadDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
@@ -2047,7 +2055,7 @@
   if (ThisDeclID == Redecl.getFirstID()) {
     // This VarTemplateDecl owns a CommonPtr; read it to keep track of all of
     // the specializations.
-    SmallVector<serialization::DeclID, 32> SpecIDs;
+    SmallVector<LazySpecializationInfo, 32> SpecIDs;
     ReadDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
@@ -2153,7 +2161,7 @@
 
   if (ThisDeclID == Redecl.getFirstID()) {
     // This FunctionTemplateDecl owns a CommonPtr; read it.
-    SmallVector<serialization::DeclID, 32> SpecIDs;
+    SmallVector<LazySpecializationInfo, 32> SpecIDs;
     ReadDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
@@ -3723,7 +3731,9 @@
   ProcessingUpdatesRAIIObj ProcessingUpdates(*this);
   DeclUpdateOffsetsMap::iterator UpdI = DeclUpdateOffsets.find(ID);
 
-  llvm::SmallVector<serialization::DeclID, 8> PendingLazySpecializationIDs;
+  using LazySpecializationInfo
+    = RedeclarableTemplateDecl::LazySpecializationInfo;
+  llvm::SmallVector<LazySpecializationInfo, 8> PendingLazySpecializationIDs;
 
   if (UpdI != DeclUpdateOffsets.end()) {
     auto UpdateOffsets = std::move(UpdI->second);
@@ -3969,7 +3979,7 @@
 }
 
 void ASTDeclReader::UpdateDecl(Decl *D,
-   llvm::SmallVectorImpl<serialization::DeclID> &PendingLazySpecializationIDs) {
+        SmallVectorImpl<LazySpecializationInfo> &PendingLazySpecializationIDs) {
   while (Record.getIdx() < Record.size()) {
     switch ((DeclUpdateKind)Record.readInt()) {
     case UPD_CXX_ADDED_IMPLICIT_MEMBER: {
@@ -3986,7 +3996,7 @@
 
     case UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION:
       // It will be added to the template's lazy specialization set.
-      PendingLazySpecializationIDs.push_back(ReadDeclID());
+      PendingLazySpecializationIDs.push_back(ReadLazySpecializationInfo());
       break;
 
     case UPD_CXX_ADDED_ANONYMOUS_NAMESPACE: {
Index: lib/Serialization/ASTReader.cpp
===================================================================
--- lib/Serialization/ASTReader.cpp
+++ lib/Serialization/ASTReader.cpp
@@ -7001,14 +7001,23 @@
     }
   }
 
-  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
-    CTSD->getSpecializedTemplate()->LoadLazySpecializations();
-  if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D))
-    VTSD->getSpecializedTemplate()->LoadLazySpecializations();
-  if (auto *FD = dyn_cast<FunctionDecl>(D)) {
-    if (auto *Template = FD->getPrimaryTemplate())
-      Template->LoadLazySpecializations();
-  }
+  RedeclarableTemplateDecl *Template = nullptr;
+  ArrayRef<TemplateArgument> Args;
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D)) {
+    Template = CTSD->getSpecializedTemplate();
+    Args = CTSD->getTemplateArgs().asArray();
+  } else if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D)) {
+    Template = VTSD->getSpecializedTemplate();
+    Args = VTSD->getTemplateArgs().asArray();
+  } else if (auto *FD = dyn_cast<FunctionDecl>(D)) {
+    if (auto *Tmplt = FD->getPrimaryTemplate()) {
+      Template = Tmplt;
+      Args = FD->getTemplateSpecializationArgs()->asArray();
+    }
+  }
+
+  if (Template)
+    Template->loadLazySpecializationsImpl(Args);
 }
 
 CXXCtorInitializer **
Index: lib/AST/ODRHash.cpp
===================================================================
--- lib/AST/ODRHash.cpp
+++ lib/AST/ODRHash.cpp
@@ -131,30 +131,34 @@
 void ODRHash::AddTemplateArgument(TemplateArgument TA) {
   const auto Kind = TA.getKind();
   ID.AddInteger(Kind);
-
   switch (Kind) {
-    case TemplateArgument::Null:
-      llvm_unreachable("Expected valid TemplateArgument");
-    case TemplateArgument::Type:
-      AddQualType(TA.getAsType());
-      break;
-    case TemplateArgument::Declaration:
-    case TemplateArgument::NullPtr:
-    case TemplateArgument::Integral:
-      break;
-    case TemplateArgument::Template:
-    case TemplateArgument::TemplateExpansion:
-      AddTemplateName(TA.getAsTemplateOrTemplatePattern());
-      break;
-    case TemplateArgument::Expression:
-      AddStmt(TA.getAsExpr());
-      break;
-    case TemplateArgument::Pack:
-      ID.AddInteger(TA.pack_size());
-      for (auto SubTA : TA.pack_elements()) {
-        AddTemplateArgument(SubTA);
-      }
-      break;
+  case TemplateArgument::Null:
+    llvm_unreachable("Require valid TemplateArgument");
+  case TemplateArgument::Type:
+    AddQualType(TA.getAsType());
+    break;
+  case TemplateArgument::Declaration:
+    AddDecl(TA.getAsDecl());
+    break;
+  case TemplateArgument::NullPtr:
+    AddQualType(TA.getNullPtrType());
+    break;
+  case TemplateArgument::Integral:
+    TA.getAsIntegral().Profile(ID);
+    AddQualType(TA.getIntegralType());
+    break;
+  case TemplateArgument::Template:
+  case TemplateArgument::TemplateExpansion:
+    AddTemplateName(TA.getAsTemplateOrTemplatePattern());
+    break;
+  case TemplateArgument::Expression:
+    AddStmt(TA.getAsExpr());
+    break;
+  case TemplateArgument::Pack:
+    ID.AddInteger(TA.pack_size());
+    for (auto SubTA : TA.pack_elements())
+      AddTemplateArgument(SubTA);
+    break;
   }
 }
 
@@ -515,6 +519,21 @@
   if (const NamedDecl *ND = dyn_cast<NamedDecl>(D)) {
     AddDeclarationName(ND->getDeclName());
   }
+
+  // If this was a specialization we should take into account its template
+  // arguments. This helps to reduce collisions coming when visiting template
+  // specialization types (eg. when processing type template arguments).
+  ArrayRef<TemplateArgument> Args;
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
+    Args = CTSD->getTemplateArgs().asArray();
+  else if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D))
+    Args = VTSD->getTemplateArgs().asArray();
+  else if (auto *FD = dyn_cast<FunctionDecl>(D))
+    if (FD->getTemplateSpecializationArgs())
+      Args = FD->getTemplateSpecializationArgs()->asArray();
+
+  for (auto &TA : Args)
+    AddTemplateArgument(TA);
 }
 
 namespace {
Index: lib/AST/DeclTemplate.cpp
===================================================================
--- lib/AST/DeclTemplate.cpp
+++ lib/AST/DeclTemplate.cpp
@@ -17,6 +17,7 @@
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclarationName.h"
 #include "clang/AST/Expr.h"
+#include "clang/AST/ODRHash.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
 #include "clang/AST/Type.h"
@@ -182,30 +183,71 @@
   return Common;
 }
 
-void RedeclarableTemplateDecl::loadLazySpecializationsImpl() const {
+void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
+                                             bool OnlyPartial/*=false*/) const {
   // Grab the most recent declaration to ensure we've loaded any lazy
   // redeclarations of this template.
   CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
-  if (CommonBasePtr->LazySpecializations) {
-    ASTContext &Context = getASTContext();
-    uint32_t *Specs = CommonBasePtr->LazySpecializations;
-    CommonBasePtr->LazySpecializations = nullptr;
-    for (uint32_t I = 0, N = *Specs++; I != N; ++I)
-      (void)Context.getExternalSource()->GetExternalDecl(Specs[I]);
+  if (auto *Specs = CommonBasePtr->LazySpecializations) {
+    if (!OnlyPartial)
+      CommonBasePtr->LazySpecializations = nullptr;
+    for (uint32_t I = 0, N = Specs[0].DeclID; I != N; ++I) {
+      // Skip over already loaded specializations.
+      if (!Specs[I+1].ODRHash)
+        continue;
+      if (!OnlyPartial || Specs[I+1].IsPartial)
+        (void)loadLazySpecializationImpl(Specs[I+1]);
+    }
   }
 }
 
+Decl *RedeclarableTemplateDecl::loadLazySpecializationImpl(
+                                   LazySpecializationInfo &LazySpecInfo) const {
+  uint32_t ID = LazySpecInfo.DeclID;
+  assert(ID && "Loading already loaded specialization!");
+  // Note that we loaded the specialization.
+  LazySpecInfo.DeclID = LazySpecInfo.ODRHash = LazySpecInfo.IsPartial = 0;
+  return getASTContext().getExternalSource()->GetExternalDecl(ID);
+}
+
+bool
+RedeclarableTemplateDecl::loadLazySpecializationsImpl(ArrayRef<TemplateArgument>
+                                                      Args) const {
+  bool LoadedSpecialization = false;
+  CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
+  if (auto *Specs = CommonBasePtr->LazySpecializations) {
+    unsigned Hash = TemplateArgumentList::ComputeODRHash(Args);
+    for (uint32_t I = 0, N = Specs[0].DeclID; I != N; ++I)
+      if (Specs[I+1].ODRHash && Specs[I+1].ODRHash == Hash)
+        LoadedSpecialization |= (bool)loadLazySpecializationImpl(Specs[I+1]);
+  }
+  return LoadedSpecialization;
+ }
+
 template<class EntryType>
 typename RedeclarableTemplateDecl::SpecEntryTraits<EntryType>::DeclType *
 RedeclarableTemplateDecl::findSpecializationImpl(
     llvm::FoldingSetVector<EntryType> &Specs, ArrayRef<TemplateArgument> Args,
     void *&InsertPos) {
   using SETraits = SpecEntryTraits<EntryType>;
 
+  (void)loadLazySpecializationsImpl(Args);
+
   llvm::FoldingSetNodeID ID;
   EntryType::Profile(ID, Args, getASTContext());
   EntryType *Entry = Specs.FindNodeOrInsertPos(ID, InsertPos);
-  return Entry ? SETraits::getDecl(Entry)->getMostRecentDecl() : nullptr;
+  if (!Entry)
+    return nullptr;
+  auto *D = SETraits::getDecl(Entry)->getMostRecentDecl();
+#ifndef NDEBUG
+  // Let's make sure that the hash of the template arguments matches the hash
+  // of the ones in the found specialization.
+  unsigned ArgHash = TemplateArgumentList::ComputeODRHash(Args);
+  unsigned DArgHash
+     = TemplateArgumentList::ComputeODRHash(SETraits::getTemplateArgs(Entry));
+  assert (ArgHash == DArgHash && "Hashes differ!");
+#endif // NDEBUG
+  return D;
 }
 
 template<class Derived, class EntryType>
@@ -216,10 +258,11 @@
 
   if (InsertPos) {
 #ifndef NDEBUG
+    auto Args = SETraits::getTemplateArgs(Entry);
+    assert(!loadLazySpecializationsImpl(Args) &&
+           "Specialization is already registered as lazy");
     void *CorrectInsertPos;
-    assert(!findSpecializationImpl(Specializations,
-                                   SETraits::getTemplateArgs(Entry),
-                                   CorrectInsertPos) &&
+    assert(!findSpecializationImpl(Specializations, Args, CorrectInsertPos) &&
            InsertPos == CorrectInsertPos &&
            "given incorrect InsertPos for specialization");
 #endif
@@ -276,12 +319,14 @@
 FunctionDecl *
 FunctionTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                          void *&InsertPos) {
-  return findSpecializationImpl(getSpecializations(), Args, InsertPos);
+  auto *Common = getCommonPtr();
+  return findSpecializationImpl(Common->Specializations, Args, InsertPos);
 }
 
 void FunctionTemplateDecl::addSpecialization(
       FunctionTemplateSpecializationInfo *Info, void *InsertPos) {
-  addSpecializationImpl<FunctionTemplateDecl>(getSpecializations(), Info,
+  auto *Common = getCommonPtr();
+  addSpecializationImpl<FunctionTemplateDecl>(Common->Specializations, Info,
                                               InsertPos);
 }
 
@@ -331,8 +376,9 @@
                                        DeclarationName(), nullptr, nullptr);
 }
 
-void ClassTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
+void ClassTemplateDecl::LoadLazySpecializations(
+                                             bool OnlyPartial/*=false*/) const {
+  loadLazySpecializationsImpl(OnlyPartial);
 }
 
 llvm::FoldingSetVector<ClassTemplateSpecializationDecl> &
@@ -343,7 +389,7 @@
 
 llvm::FoldingSetVector<ClassTemplatePartialSpecializationDecl> &
 ClassTemplateDecl::getPartialSpecializations() {
-  LoadLazySpecializations();
+  LoadLazySpecializations(/*PartialOnly = */ true);
   return getCommonPtr()->PartialSpecializations;
 }  
 
@@ -357,12 +403,15 @@
 ClassTemplateSpecializationDecl *
 ClassTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                       void *&InsertPos) {
-  return findSpecializationImpl(getSpecializations(), Args, InsertPos);
+  auto *Common = getCommonPtr();
+  return findSpecializationImpl(Common->Specializations, Args, InsertPos);
 }
 
 void ClassTemplateDecl::AddSpecialization(ClassTemplateSpecializationDecl *D,
                                           void *InsertPos) {
-  addSpecializationImpl<ClassTemplateDecl>(getSpecializations(), D, InsertPos);
+  auto *Common = getCommonPtr();
+  addSpecializationImpl<ClassTemplateDecl>(Common->Specializations, D,
+                                           InsertPos);
 }
 
 ClassTemplatePartialSpecializationDecl *
@@ -653,6 +702,14 @@
   return new (Mem) TemplateArgumentList(Args);
 }
 
+unsigned TemplateArgumentList::ComputeODRHash(ArrayRef<TemplateArgument> Args) {
+  ODRHash Hasher;
+  for (TemplateArgument TA : Args)
+    Hasher.AddTemplateArgument(TA);
+
+  return Hasher.CalculateHash();
+}
+
 FunctionTemplateSpecializationInfo *
 FunctionTemplateSpecializationInfo::Create(ASTContext &C, FunctionDecl *FD,
                                            FunctionTemplateDecl *Template,
@@ -932,8 +989,9 @@
                                      DeclarationName(), nullptr, nullptr);
 }
 
-void VarTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
+void VarTemplateDecl::LoadLazySpecializations(
+                                             bool OnlyPartial/*=false*/) const {
+  loadLazySpecializationsImpl(OnlyPartial);
 }
 
 llvm::FoldingSetVector<VarTemplateSpecializationDecl> &
@@ -944,7 +1002,7 @@
 
 llvm::FoldingSetVector<VarTemplatePartialSpecializationDecl> &
 VarTemplateDecl::getPartialSpecializations() {
-  LoadLazySpecializations();
+  LoadLazySpecializations(/*PartialOnly = */ true);
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -958,12 +1016,14 @@
 VarTemplateSpecializationDecl *
 VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                     void *&InsertPos) {
-  return findSpecializationImpl(getSpecializations(), Args, InsertPos);
+  auto *Common = getCommonPtr();
+  return findSpecializationImpl(Common->Specializations, Args, InsertPos);
 }
 
 void VarTemplateDecl::AddSpecialization(VarTemplateSpecializationDecl *D,
                                         void *InsertPos) {
-  addSpecializationImpl<VarTemplateDecl>(getSpecializations(), D, InsertPos);
+  auto *Common = getCommonPtr();
+  addSpecializationImpl<VarTemplateDecl>(Common->Specializations, D, InsertPos);
 }
 
 VarTemplatePartialSpecializationDecl *
Index: include/clang/AST/DeclTemplate.h
===================================================================
--- include/clang/AST/DeclTemplate.h
+++ include/clang/AST/DeclTemplate.h
@@ -235,6 +235,9 @@
   static TemplateArgumentList *CreateCopy(ASTContext &Context,
                                           ArrayRef<TemplateArgument> Args);
 
+  /// \brief Create hash for the given arguments.
+  static unsigned ComputeODRHash(ArrayRef<TemplateArgument> Args);
+
   /// \brief Construct a new, temporary template argument list on the stack.
   ///
   /// The template argument list does not own the template arguments
@@ -750,6 +753,26 @@
     return getMostRecentDecl();
   }
 
+  struct LazySpecializationInfo {
+    uint32_t DeclID = ~0U;
+    unsigned ODRHash = ~0U;
+    bool IsPartial = false;
+    LazySpecializationInfo(uint32_t ID, unsigned Hash = ~0U,
+                           bool Partial = false)
+      : DeclID(ID), ODRHash(Hash), IsPartial(Partial) { }
+    LazySpecializationInfo() { }
+    bool operator<(const LazySpecializationInfo &Other) const {
+      return DeclID < Other.DeclID;
+    }
+    bool operator==(const LazySpecializationInfo &Other) const {
+      assert((DeclID != Other.DeclID || ODRHash == Other.ODRHash) &&
+             "Hashes differ!");
+      assert((DeclID != Other.DeclID || IsPartial == Other.IsPartial) &&
+             "Both must be the same kinds!");
+      return DeclID == Other.DeclID;
+    }
+  };
+
 protected:
   template <typename EntryType> struct SpecEntryTraits {
     using DeclType = EntryType;
@@ -790,7 +813,12 @@
     return SpecIterator<EntryType>(isEnd ? Specs.end() : Specs.begin());
   }
 
-  void loadLazySpecializationsImpl() const;
+  void loadLazySpecializationsImpl(bool OnlyPartial = false) const;
+
+  ///\returns true if any lazy specialization was loaded.
+  bool loadLazySpecializationsImpl(llvm::ArrayRef<TemplateArgument> Args) const;
+
+  Decl *loadLazySpecializationImpl(LazySpecializationInfo &LazySpecInfo) const;
 
   template <class EntryType> typename SpecEntryTraits<EntryType>::DeclType*
   findSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
@@ -816,7 +844,7 @@
     ///
     /// The first value in the array is the number of specializations/partial
     /// specializations that follow.
-    uint32_t *LazySpecializations = nullptr;
+    LazySpecializationInfo *LazySpecializations = nullptr;
   };
 
   /// \brief Pointer to the common data shared by all declarations of this
@@ -2101,7 +2129,7 @@
   friend class ASTDeclWriter;
 
   /// \brief Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  void LoadLazySpecializations(bool OnlyPartial = false) const;
 
   /// \brief Get the underlying class declarations of the template.
   CXXRecordDecl *getTemplatedDecl() const {
@@ -2908,7 +2936,7 @@
   friend class ASTDeclWriter;
 
   /// \brief Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  void LoadLazySpecializations(bool OnlyPartial = false) const;
 
   /// \brief Get the underlying variable declarations of the template.
   VarDecl *getTemplatedDecl() const {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to