llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-hlsl Author: Finn Plummer (inbelic) <details> <summary>Changes</summary> Implements serialization of the remaining completely defined `RootElement`s, namely `RootDescriptor`s and `StaticSampler`s. - Adds unit tests for the serialization methods Resolves https://github.com/llvm/llvm-project/issues/138191 Resolves https://github.com/llvm/llvm-project/issues/138193 --- Full diff: https://github.com/llvm/llvm-project/pull/143198.diff 3 Files Affected: - (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h (+6) - (modified) llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp (+254) - (modified) llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp (+121) ``````````diff diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h index ca20e6719f3a4..7489777670703 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h @@ -32,6 +32,12 @@ LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const RootFlags &Flags); LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants); +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, + const RootDescriptor &Descriptor); + +LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, + const StaticSampler &StaticSampler); + LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause); diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index 24486a55ecf6a..70c3e72c1f806 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -71,6 +71,199 @@ static raw_ostream &operator<<(raw_ostream &OS, return OS; } +static raw_ostream &operator<<(raw_ostream &OS, const SamplerFilter &Filter) { + switch (Filter) { + case SamplerFilter::MinMagMipPoint: + OS << "MinMagMipPoint"; +
break; + case SamplerFilter::MinMagPointMipLinear: + OS << "MinMagPointMipLinear"; + break; + case SamplerFilter::MinPointMagLinearMipPoint: + OS << "MinPointMagLinearMipPoint"; + break; + case SamplerFilter::MinPointMagMipLinear: + OS << "MinPointMagMipLinear"; + break; + case SamplerFilter::MinLinearMagMipPoint: + OS << "MinLinearMagMipPoint"; + break; + case SamplerFilter::MinLinearMagPointMipLinear: + OS << "MinLinearMagPointMipLinear"; + break; + case SamplerFilter::MinMagLinearMipPoint: + OS << "MinMagLinearMipPoint"; + break; + case SamplerFilter::MinMagMipLinear: + OS << "MinMagMipLinear"; + break; + case SamplerFilter::Anisotropic: + OS << "Anisotropic"; + break; + case SamplerFilter::ComparisonMinMagMipPoint: + OS << "ComparisonMinMagMipPoint"; + break; + case SamplerFilter::ComparisonMinMagPointMipLinear: + OS << "ComparisonMinMagPointMipLinear"; + break; + case SamplerFilter::ComparisonMinPointMagLinearMipPoint: + OS << "ComparisonMinPointMagLinearMipPoint"; + break; + case SamplerFilter::ComparisonMinPointMagMipLinear: + OS << "ComparisonMinPointMagMipLinear"; + break; + case SamplerFilter::ComparisonMinLinearMagMipPoint: + OS << "ComparisonMinLinearMagMipPoint"; + break; + case SamplerFilter::ComparisonMinLinearMagPointMipLinear: + OS << "ComparisonMinLinearMagPointMipLinear"; + break; + case SamplerFilter::ComparisonMinMagLinearMipPoint: + OS << "ComparisonMinMagLinearMipPoint"; + break; + case SamplerFilter::ComparisonMinMagMipLinear: + OS << "ComparisonMinMagMipLinear"; + break; + case SamplerFilter::ComparisonAnisotropic: + OS << "ComparisonAnisotropic"; + break; + case SamplerFilter::MinimumMinMagMipPoint: + OS << "MinimumMinMagMipPoint"; + break; + case SamplerFilter::MinimumMinMagPointMipLinear: + OS << "MinimumMinMagPointMipLinear"; + break; + case SamplerFilter::MinimumMinPointMagLinearMipPoint: + OS << "MinimumMinPointMagLinearMipPoint"; + break; + case SamplerFilter::MinimumMinPointMagMipLinear: + OS << "MinimumMinPointMagMipLinear"; + 
break; + case SamplerFilter::MinimumMinLinearMagMipPoint: + OS << "MinimumMinLinearMagMipPoint"; + break; + case SamplerFilter::MinimumMinLinearMagPointMipLinear: + OS << "MinimumMinLinearMagPointMipLinear"; + break; + case SamplerFilter::MinimumMinMagLinearMipPoint: + OS << "MinimumMinMagLinearMipPoint"; + break; + case SamplerFilter::MinimumMinMagMipLinear: + OS << "MinimumMinMagMipLinear"; + break; + case SamplerFilter::MinimumAnisotropic: + OS << "MinimumAnisotropic"; + break; + case SamplerFilter::MaximumMinMagMipPoint: + OS << "MaximumMinMagMipPoint"; + break; + case SamplerFilter::MaximumMinMagPointMipLinear: + OS << "MaximumMinMagPointMipLinear"; + break; + case SamplerFilter::MaximumMinPointMagLinearMipPoint: + OS << "MaximumMinPointMagLinearMipPoint"; + break; + case SamplerFilter::MaximumMinPointMagMipLinear: + OS << "MaximumMinPointMagMipLinear"; + break; + case SamplerFilter::MaximumMinLinearMagMipPoint: + OS << "MaximumMinLinearMagMipPoint"; + break; + case SamplerFilter::MaximumMinLinearMagPointMipLinear: + OS << "MaximumMinLinearMagPointMipLinear"; + break; + case SamplerFilter::MaximumMinMagLinearMipPoint: + OS << "MaximumMinMagLinearMipPoint"; + break; + case SamplerFilter::MaximumMinMagMipLinear: + OS << "MaximumMinMagMipLinear"; + break; + case SamplerFilter::MaximumAnisotropic: + OS << "MaximumAnisotropic"; + break; + } + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const TextureAddressMode &Address) { + switch (Address) { + case TextureAddressMode::Wrap: + OS << "Wrap"; + break; + case TextureAddressMode::Mirror: + OS << "Mirror"; + break; + case TextureAddressMode::Clamp: + OS << "Clamp"; + break; + case TextureAddressMode::Border: + OS << "Border"; + break; + case TextureAddressMode::MirrorOnce: + OS << "MirrorOnce"; + break; + } + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const ComparisonFunc &CompFunc) { + switch (CompFunc) { + case ComparisonFunc::Never: + OS << "Never"; + break; + case 
ComparisonFunc::Less: + OS << "Less"; + break; + case ComparisonFunc::Equal: + OS << "Equal"; + break; + case ComparisonFunc::LessEqual: + OS << "LessEqual"; + break; + case ComparisonFunc::Greater: + OS << "Greater"; + break; + case ComparisonFunc::NotEqual: + OS << "NotEqual"; + break; + case ComparisonFunc::GreaterEqual: + OS << "GreaterEqual"; + break; + case ComparisonFunc::Always: + OS << "Always"; + break; + } + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const StaticBorderColor &BorderColor) { + switch (BorderColor) { + case StaticBorderColor::TransparentBlack: + OS << "TransparentBlack"; + break; + case StaticBorderColor::OpaqueBlack: + OS << "OpaqueBlack"; + break; + case StaticBorderColor::OpaqueWhite: + OS << "OpaqueWhite"; + break; + case StaticBorderColor::OpaqueBlackUint: + OS << "OpaqueBlackUint"; + break; + case StaticBorderColor::OpaqueWhiteUint: + OS << "OpaqueWhiteUint"; + break; + } + + return OS; +} + static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { switch (Type) { case ClauseType::CBuffer: @@ -132,6 +325,42 @@ static raw_ostream &operator<<(raw_ostream &OS, return OS; } +static raw_ostream &operator<<(raw_ostream &OS, + const RootDescriptorFlags &Flags) { + bool FlagSet = false; + unsigned Remaining = llvm::to_underlying(Flags); + while (Remaining) { + unsigned Bit = 1u << llvm::countr_zero(Remaining); + if (Remaining & Bit) { + if (FlagSet) + OS << " | "; + + switch (static_cast<RootDescriptorFlags>(Bit)) { + case RootDescriptorFlags::DataVolatile: + OS << "DataVolatile"; + break; + case RootDescriptorFlags::DataStaticWhileSetAtExecute: + OS << "DataStaticWhileSetAtExecute"; + break; + case RootDescriptorFlags::DataStatic: + OS << "DataStatic"; + break; + default: + OS << "invalid: " << Bit; + break; + } + + FlagSet = true; + } + Remaining &= ~Bit; + } + + if (!FlagSet) + OS << "None"; + + return OS; +} + raw_ostream &operator<<(raw_ostream &OS, const RootFlags &Flags) { OS << "RootFlags("; 
bool FlagSet = false; @@ -205,6 +434,31 @@ raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) { return OS; } +raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) { + ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type)); + OS << "Root" << Type << "(" << Descriptor.Reg + << ", space = " << Descriptor.Space + << ", visibility = " << Descriptor.Visibility + << ", flags = " << Descriptor.Flags << ")"; + + return OS; +} + +raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) { + OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter + << ", addressU = " << Sampler.AddressU + << ", addressV = " << Sampler.AddressV + << ", addressW = " << Sampler.AddressW + << ", mipLODBias = " << Sampler.MipLODBias + << ", maxAnisotropy = " << Sampler.MaxAnisotropy + << ", comparisonFunc = " << Sampler.CompFunc + << ", borderColor = " << Sampler.BorderColor + << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD + << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility + << ")"; + return OS; +} + raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) { OS << "DescriptorTable(numClauses = " << Table.NumClauses << ", visibility = " << Table.Visibility << ")"; diff --git a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp index 1a0c8e2a16396..831c5dd585fab 100644 --- a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp +++ b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp @@ -177,4 +177,125 @@ TEST(HLSLRootSignatureTest, AllRootFlagsDump) { EXPECT_EQ(Out, Expected); } +TEST(HLSLRootSignatureTest, RootCBVDump) { + RootDescriptor Descriptor; + Descriptor.Type = DescriptorType::CBuffer; + Descriptor.Reg = {RegisterType::BReg, 0}; + Descriptor.setDefaultFlags(); + + std::string Out; + llvm::raw_string_ostream OS(Out); + OS << Descriptor; + OS.flush(); + + std::string Expected = 
"RootCBV(b0, space = 0, " + "visibility = All, " + "flags = DataStaticWhileSetAtExecute)"; + EXPECT_EQ(Out, Expected); +} + +TEST(HLSLRootSignatureTest, RootSRVDump) { + RootDescriptor Descriptor; + Descriptor.Type = DescriptorType::SRV; + Descriptor.Reg = {RegisterType::TReg, 0}; + Descriptor.Space = 42; + Descriptor.Visibility = ShaderVisibility::Geometry; + Descriptor.Flags = RootDescriptorFlags::None; + + std::string Out; + llvm::raw_string_ostream OS(Out); + OS << Descriptor; + OS.flush(); + + std::string Expected = + "RootSRV(t0, space = 42, visibility = Geometry, flags = None)"; + EXPECT_EQ(Out, Expected); +} + +TEST(HLSLRootSignatureTest, RootUAVDump) { + RootDescriptor Descriptor; + Descriptor.Type = DescriptorType::UAV; + Descriptor.Reg = {RegisterType::UReg, 92374}; + Descriptor.Space = 932847; + Descriptor.Visibility = ShaderVisibility::Hull; + Descriptor.Flags = RootDescriptorFlags::ValidFlags; + + std::string Out; + llvm::raw_string_ostream OS(Out); + OS << Descriptor; + OS.flush(); + + std::string Expected = + "RootUAV(u92374, space = 932847, visibility = Hull, flags = " + "DataVolatile | " + "DataStaticWhileSetAtExecute | " + "DataStatic)"; + EXPECT_EQ(Out, Expected); +} + +TEST(HLSLRootSignatureTest, DefaultStaticSamplerDump) { + StaticSampler Sampler; + Sampler.Reg = {RegisterType::SReg, 0}; + + std::string Out; + llvm::raw_string_ostream OS(Out); + OS << Sampler; + OS.flush(); + + std::string Expected = "StaticSampler(s0, " + "filter = Anisotropic, " + "addressU = Wrap, " + "addressV = Wrap, " + "addressW = Wrap, " + "mipLODBias = 0.000000e+00, " + "maxAnisotropy = 16, " + "comparisonFunc = LessEqual, " + "borderColor = OpaqueWhite, " + "minLOD = 0.000000e+00, " + "maxLOD = 3.402823e+38, " + "space = 0, " + "visibility = All" + ")"; + EXPECT_EQ(Out, Expected); +} + +TEST(HLSLRootSignatureTest, DefinedStaticSamplerDump) { + StaticSampler Sampler; + Sampler.Reg = {RegisterType::SReg, 0}; + + Sampler.Filter = 
SamplerFilter::ComparisonMinMagLinearMipPoint; + Sampler.AddressU = TextureAddressMode::Mirror; + Sampler.AddressV = TextureAddressMode::Border; + Sampler.AddressW = TextureAddressMode::Clamp; + Sampler.MipLODBias = 4.8f; + Sampler.MaxAnisotropy = 32; + Sampler.CompFunc = ComparisonFunc::NotEqual; + Sampler.BorderColor = StaticBorderColor::OpaqueBlack; + Sampler.MinLOD = 1.0f; + Sampler.MaxLOD = 32.0f; + Sampler.Space = 7; + Sampler.Visibility = ShaderVisibility::Domain; + + std::string Out; + llvm::raw_string_ostream OS(Out); + OS << Sampler; + OS.flush(); + + std::string Expected = "StaticSampler(s0, " + "filter = ComparisonMinMagLinearMipPoint, " + "addressU = Mirror, " + "addressV = Border, " + "addressW = Clamp, " + "mipLODBias = 4.800000e+00, " + "maxAnisotropy = 32, " + "comparisonFunc = NotEqual, " + "borderColor = OpaqueBlack, " + "minLOD = 1.000000e+00, " + "maxLOD = 3.200000e+01, " + "space = 7, " + "visibility = Domain" + ")"; + EXPECT_EQ(Out, Expected); +} + } // namespace `````````` </details> https://github.com/llvm/llvm-project/pull/143198 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits