================
@@ -839,100 +842,84 @@ static void 
updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
       const Type *FieldTy = FD->getType().getTypePtr();
       if (const HLSLAttributedResourceType *AttrResType =
               dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        updateResourceClassFlagsFromDeclResourceClass(
-            Flags, AttrResType->getAttrs().ResourceClass);
-        continue;
+        ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+        if (getRegisterType(RC) == RegType)
+          return true;
+      } else {
+        TypesToScan.emplace_back(FD->getType().getTypePtr());
       }
-      TypesToScan.emplace_back(FD->getType().getTypePtr());
     }
   }
+  return false;
 }
 
-static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
-                                                         Decl *TheDecl) {
-  RegisterBindingFlags Flags;
+static void CheckContainsResourceForRegisterType(Sema &S,
+                                                 SourceLocation &ArgLoc,
+                                                 Decl *D, RegisterType RegType,
+                                                 bool SpecifiedSpace) {
+  int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
-  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
-    Flags.Other = true;
-    return Flags;
+  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return;
   }
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
-  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
-    Flags.Resource = true;
-    Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
-                              ? llvm::dxil::ResourceClass::CBuffer
-                              : llvm::dxil::ResourceClass::SRV;
+  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
+                                                     : ResourceClass::SRV;
+    if (RegType != getRegisterType(RC))
+      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+          << RegTypeNum;
+    return;
   }
+
   // Samplers, UAVs, and SRVs are VarDecl types
-  else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
-    if (const HLSLAttributedResourceType *AttrResType =
-            findAttributedResourceTypeOnField(TheVarDecl)) {
-      Flags.Resource = true;
-      Flags.ResourceClass = AttrResType->getAttrs().ResourceClass;
-    } else {
-      const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
-      while (TheBaseType->isArrayType())
-        TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-
-      if (TheBaseType->isArithmeticType()) {
-        Flags.Basic = true;
-        if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
-            (TheBaseType->isIntegralType(S.getASTContext()) ||
-             TheBaseType->isFloatingType()))
-          Flags.DefaultGlobals = true;
-      } else if (TheBaseType->isRecordType()) {
-        Flags.UDT = true;
-        const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
-        updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
-      } else
-        Flags.Other = true;
-    }
-  } else {
-    llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
+  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
+  VarDecl *VD = cast<VarDecl>(D);
----------------
bogner wrote:

FWIW `cast<>` will already assert if the cast fails, so the
`assert(isa<VarDecl>(D) && ...)` immediately before it is largely redundant.
It's debatable whether the slightly more specific message is worth a separate
assert, so the extra assert doesn't necessarily need to be removed, but I
figured I'd point that out.

https://github.com/llvm/llvm-project/pull/108924
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to