llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-directx Author: None (joaosaffran) <details> <summary>Changes</summary> This PR changes the validation logic for Root Descriptor and Descriptor Range flags to properly check if the `uint32_t` values are within range before casting into the enums. --- Full diff: https://github.com/llvm/llvm-project/pull/161587.diff 6 Files Affected: - (modified) clang/lib/Sema/SemaHLSL.cpp (+2-2) - (modified) llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h (+1-1) - (modified) llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp (+1-2) - (modified) llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp (+11-1) - (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll (+20) - (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll (+18) ``````````diff diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 129b03c07c0bd..a2e8afb9bb8ff 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1322,8 +1322,8 @@ bool SemaHLSL::handleRootSignatureElements( ReportError(Loc, 1, 0xfffffffe); } - if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type, - Clause->Flags)) + if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( + Version, Clause->Type, llvm::to_underlying(Clause->Flags))) ReportFlagError(Loc); } } diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index 4dd18111b0c9d..10723a181f025 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -32,7 +32,7 @@ LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal); LLVM_ABI bool verifyRangeType(uint32_t Type); LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, - dxbc::DescriptorRangeFlags FlagsVal); + 
uint32_t FlagsVal); LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber); LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors); LLVM_ABI bool verifyMipLODBias(float MipLODBias); diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 5785505ce2b0c..2a22364563f90 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -665,8 +665,7 @@ Error MetadataParser::validateRootSignature( "NumDescriptors", Range.NumDescriptors)); if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, - dxbc::DescriptorRangeFlags(Range.Flags))) + RSD.Version, Range.RangeType, Range.Flags)) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 2c78d622f7f28..e887906955dd2 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -36,6 +36,11 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) { bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { using FlagT = dxbc::RootDescriptorFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + if (FlagsVal >= NextPowerOf2(LargestValue)) + return false; + FlagT Flags = FlagT(FlagsVal); if (Version == 1) return Flags == FlagT::DataVolatile; @@ -54,9 +59,14 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { } bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, - dxbc::DescriptorRangeFlags Flags) { + uint32_t FlagsVal) { using FlagT = dxbc::DescriptorRangeFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + if (FlagsVal >= NextPowerOf2(LargestValue)) + return false; + FlagT Flags = 
FlagT(FlagsVal); const bool IsSampler = (Type == dxil::ResourceClass::Sampler); if (Version == 1) { diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll new file mode 100644 index 0000000000000..c27c87ff057d5 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll @@ -0,0 +1,20 @@ +; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s + +target triple = "dxil-unknown-shadermodel6.0-compute" + +; CHECK: error: Invalid value for DescriptorFlag: 66666 +; CHECK-NOT: Root Signature Definitions + +define void @main() #0 { +entry: + ret void +} +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + + +!dx.rootsignatures = !{!2} ; list of function/root signature pairs +!2 = !{ ptr @main, !3, i32 2 } ; function, root signature +!3 = !{ !5 } ; list of root signature elements +!5 = !{ !"DescriptorTable", i32 0, !6, !7 } +!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 66666 } +!7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 } diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll new file mode 100644 index 0000000000000..898e197c7e0cc --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll @@ -0,0 +1,18 @@ +; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s + +target triple = "dxil-unknown-shadermodel6.0-compute" + + +; CHECK: error: Invalid value for RootDescriptorFlag: 666 +; CHECK-NOT: Root Signature Definitions +define void @main() #0 { +entry: + ret void +} +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + + +!dx.rootsignatures = !{!2} ; 
list of function/root signature pairs +!2 = !{ ptr @main, !3, i32 2 } ; function, root signature +!3 = !{ !5 } ; list of root signature elements +!5 = !{ !"RootCBV", i32 0, i32 1, i32 2, i32 666 } `````````` </details> https://github.com/llvm/llvm-project/pull/161587 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
