================ @@ -839,100 +842,84 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags, const Type *FieldTy = FD->getType().getTypePtr(); if (const HLSLAttributedResourceType *AttrResType = dyn_cast<HLSLAttributedResourceType>(FieldTy)) { - updateResourceClassFlagsFromDeclResourceClass( - Flags, AttrResType->getAttrs().ResourceClass); - continue; + ResourceClass RC = AttrResType->getAttrs().ResourceClass; + if (getRegisterType(RC) == RegType) + return true; + } else { + TypesToScan.emplace_back(FD->getType().getTypePtr()); } - TypesToScan.emplace_back(FD->getType().getTypePtr()); } } + return false; } -static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, - Decl *TheDecl) { - RegisterBindingFlags Flags; +static void CheckContainsResourceForRegisterType(Sema &S, + SourceLocation &ArgLoc, + Decl *D, RegisterType RegType, + bool SpecifiedSpace) { + int RegTypeNum = static_cast<int>(RegType); // check if the decl type is groupshared - if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) { - Flags.Other = true; - return Flags; + if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) { + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + return; } // Cbuffers and Tbuffers are HLSLBufferDecl types - if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) { - Flags.Resource = true; - Flags.ResourceClass = CBufferOrTBuffer->isCBuffer() - ? llvm::dxil::ResourceClass::CBuffer - : llvm::dxil::ResourceClass::SRV; + if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { + ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? 
ResourceClass::CBuffer + : ResourceClass::SRV; + if (RegType != getRegisterType(RC)) + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return; } + // Samplers, UAVs, and SRVs are VarDecl types - else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) { - if (const HLSLAttributedResourceType *AttrResType = - findAttributedResourceTypeOnField(TheVarDecl)) { - Flags.Resource = true; - Flags.ResourceClass = AttrResType->getAttrs().ResourceClass; - } else { - const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr(); - while (TheBaseType->isArrayType()) - TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - - if (TheBaseType->isArithmeticType()) { - Flags.Basic = true; - if (!isDeclaredWithinCOrTBuffer(TheDecl) && - (TheBaseType->isIntegralType(S.getASTContext()) || - TheBaseType->isFloatingType())) - Flags.DefaultGlobals = true; - } else if (TheBaseType->isRecordType()) { - Flags.UDT = true; - const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>(); - updateResourceClassFlagsFromRecordType(Flags, TheRecordTy); - } else - Flags.Other = true; - } - } else { - llvm_unreachable("expected be VarDecl or HLSLBufferDecl"); + assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl"); + VarDecl *VD = cast<VarDecl>(D); ---------------- bogner wrote:
FWIW, `cast<>` already asserts if the cast fails. It's debatable whether the small amount of extra information in the specific message is worth a separate assert, so this doesn't necessarily need to change, but I figured I'd point it out. https://github.com/llvm/llvm-project/pull/108924 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits