https://github.com/joaosaffran created https://github.com/llvm/llvm-project/pull/137284
None >From 7219ed4328aff2929f021c5efbd6901bc4bd2e20 Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Fri, 25 Apr 2025 05:09:08 +0000 Subject: [PATCH] refactoring mcdxbc struct to store root parameters out of order --- .../llvm/MC/DXContainerRootSignature.h | 138 +++++++++++++++++- .../include/llvm/ObjectYAML/DXContainerYAML.h | 2 +- llvm/lib/MC/DXContainerRootSignature.cpp | 78 +++++----- llvm/lib/ObjectYAML/DXContainerEmitter.cpp | 36 ++--- llvm/lib/Target/DirectX/DXILRootSignature.cpp | 45 +++--- 5 files changed, 208 insertions(+), 91 deletions(-) diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h index 1f421d726bf38..e1f4abbcebf8f 100644 --- a/llvm/include/llvm/MC/DXContainerRootSignature.h +++ b/llvm/include/llvm/MC/DXContainerRootSignature.h @@ -6,22 +6,146 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/STLForwardCompat.h" #include "llvm/BinaryFormat/DXContainer.h" +#include "llvm/Support/ErrorHandling.h" +#include <cstddef> #include <cstdint> -#include <limits> +#include <variant> namespace llvm { class raw_ostream; namespace mcdxbc { +struct RootParameterHeader : public dxbc::RootParameterHeader { + + size_t Location; + + RootParameterHeader() = default; + + RootParameterHeader(dxbc::RootParameterHeader H, size_t L) + : dxbc::RootParameterHeader(H), Location(L) {} +}; + +using RootDescriptor = std::variant<dxbc::RST0::v0::RootDescriptor, + dxbc::RST0::v1::RootDescriptor>; +using ParametersView = + std::variant<dxbc::RootConstants, dxbc::RST0::v0::RootDescriptor, + dxbc::RST0::v1::RootDescriptor>; struct RootParameter { - dxbc::RootParameterHeader Header; - union { - dxbc::RootConstants Constants; - dxbc::RST0::v0::RootDescriptor Descriptor_V10; - dxbc::RST0::v1::RootDescriptor Descriptor_V11; + SmallVector<RootParameterHeader> Headers; + + SmallVector<dxbc::RootConstants> Constants; + SmallVector<RootDescriptor> 
Descriptors; + + void addHeader(dxbc::RootParameterHeader H, size_t L) { + Headers.push_back(RootParameterHeader(H, L)); + } + + void addParameter(dxbc::RootParameterHeader H, dxbc::RootConstants C) { + addHeader(H, Constants.size()); + Constants.push_back(C); + } + + void addParameter(dxbc::RootParameterHeader H, + dxbc::RST0::v0::RootDescriptor D) { + addHeader(H, Descriptors.size()); + Descriptors.push_back(D); + } + + void addParameter(dxbc::RootParameterHeader H, + dxbc::RST0::v1::RootDescriptor D) { + addHeader(H, Descriptors.size()); + Descriptors.push_back(D); + } + + ParametersView get(const RootParameterHeader &H) const { + switch (H.ParameterType) { + case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): + return Constants[H.Location]; + case llvm::to_underlying(dxbc::RootParameterType::CBV): + case llvm::to_underlying(dxbc::RootParameterType::SRV): + case llvm::to_underlying(dxbc::RootParameterType::UAV): + RootDescriptor VersionedParam = Descriptors[H.Location]; + if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>( + VersionedParam)) + return std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam); + return std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam); + } + + llvm_unreachable("Unimplemented parameter type"); + } + + struct iterator { + const RootParameter &Parameters; + SmallVector<RootParameterHeader>::const_iterator Current; + + // Holds a reference to the owning RootParameter; the iterator must not outlive it. + iterator(const RootParameter &P, + SmallVector<RootParameterHeader>::const_iterator C) + : Parameters(P), Current(C) {} + iterator(const iterator &) = default; + + ParametersView operator*() { + ParametersView Val; + switch (Current->ParameterType) { + case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): + Val = Parameters.Constants[Current->Location]; + break; + + case llvm::to_underlying(dxbc::RootParameterType::CBV): + case llvm::to_underlying(dxbc::RootParameterType::SRV): + case 
llvm::to_underlying(dxbc::RootParameterType::UAV): + RootDescriptor VersionedParam = + Parameters.Descriptors[Current->Location]; + if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>( + VersionedParam)) + Val = std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam); + else + Val = std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam); + break; + } + return Val; + } + + iterator operator++() { + Current++; + return *this; + } + + iterator operator++(int) { + iterator Tmp = *this; + ++*this; + return Tmp; + } + + iterator operator--() { + Current--; + return *this; + } + + iterator operator--(int) { + iterator Tmp = *this; + --*this; + return Tmp; + } + + bool operator==(const iterator I) { return I.Current == Current; } + bool operator!=(const iterator I) { return !(*this == I); } }; + + iterator begin() const { return iterator(*this, Headers.begin()); } + + iterator end() const { return iterator(*this, Headers.end()); } + + size_t size() const { return Headers.size(); } + + bool isEmpty() const { return Headers.empty(); } + + llvm::iterator_range<RootParameter::iterator> getAll() const { + return llvm::make_range(begin(), end()); + } }; struct RootSignatureDesc { @@ -30,7 +154,7 @@ struct RootSignatureDesc { uint32_t RootParameterOffset = 0U; uint32_t StaticSamplersOffset = 0u; uint32_t NumStaticSamplers = 0u; - SmallVector<mcdxbc::RootParameter> Parameters; + mcdxbc::RootParameter Parameters; void write(raw_ostream &OS) const; diff --git a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h index c54c995acd263..e86a869da99bc 100644 --- a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h +++ b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h @@ -95,7 +95,7 @@ struct RootParameterYamlDesc { uint32_t Type; uint32_t Visibility; uint32_t Offset; - RootParameterYamlDesc() {}; + RootParameterYamlDesc(){}; RootParameterYamlDesc(uint32_t T) : Type(T) { switch (T) { diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp 
b/llvm/lib/MC/DXContainerRootSignature.cpp index a5210f4768f16..18242ccc1e935 100644 --- a/llvm/lib/MC/DXContainerRootSignature.cpp +++ b/llvm/lib/MC/DXContainerRootSignature.cpp @@ -8,7 +8,9 @@ #include "llvm/MC/DXContainerRootSignature.h" #include "llvm/ADT/SmallString.h" +#include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Support/EndianStream.h" +#include <variant> using namespace llvm; using namespace llvm::mcdxbc; @@ -32,22 +34,15 @@ size_t RootSignatureDesc::getSize() const { size_t Size = sizeof(dxbc::RootSignatureHeader) + Parameters.size() * sizeof(dxbc::RootParameterHeader); - for (const mcdxbc::RootParameter &P : Parameters) { - switch (P.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - Size += sizeof(dxbc::RootConstants); - break; - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - if (Version == 1) - Size += sizeof(dxbc::RST0::v0::RootDescriptor); - else - Size += sizeof(dxbc::RST0::v1::RootDescriptor); - - break; - } + for (const auto &P : Parameters) { + std::visit( + [&Size](auto &Value) -> void { + using T = std::decay_t<decltype(Value)>; + Size += sizeof(T); + }, + P); } + return Size; } @@ -66,45 +61,40 @@ void RootSignatureDesc::write(raw_ostream &OS) const { support::endian::write(BOS, Flags, llvm::endianness::little); SmallVector<uint32_t> ParamsOffsets; - for (const mcdxbc::RootParameter &P : Parameters) { - support::endian::write(BOS, P.Header.ParameterType, - llvm::endianness::little); - support::endian::write(BOS, P.Header.ShaderVisibility, - llvm::endianness::little); + for (const auto &P : Parameters.Headers) { + support::endian::write(BOS, P.ParameterType, llvm::endianness::little); + support::endian::write(BOS, P.ShaderVisibility, llvm::endianness::little); ParamsOffsets.push_back(writePlaceholder(BOS)); } assert(NumParameters == ParamsOffsets.size()); - for 
(size_t I = 0; I < NumParameters; ++I) { + auto P = Parameters.begin(); + for (size_t I = 0; I < NumParameters; ++I, P++) { rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]); - const mcdxbc::RootParameter &P = Parameters[I]; - switch (P.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - support::endian::write(BOS, P.Constants.ShaderRegister, + if (std::holds_alternative<dxbc::RootConstants>(*P)) { + auto Constants = std::get<dxbc::RootConstants>(*P); + support::endian::write(BOS, Constants.ShaderRegister, + llvm::endianness::little); + support::endian::write(BOS, Constants.RegisterSpace, llvm::endianness::little); - support::endian::write(BOS, P.Constants.RegisterSpace, + support::endian::write(BOS, Constants.Num32BitValues, + llvm::endianness::little); + } else if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(*P)) { + auto Descriptor = std::get<dxbc::RST0::v0::RootDescriptor>(*P); + support::endian::write(BOS, Descriptor.ShaderRegister, + llvm::endianness::little); + support::endian::write(BOS, Descriptor.RegisterSpace, + llvm::endianness::little); + } else if (std::holds_alternative<dxbc::RST0::v1::RootDescriptor>(*P)) { + auto Descriptor = std::get<dxbc::RST0::v1::RootDescriptor>(*P); + + support::endian::write(BOS, Descriptor.ShaderRegister, llvm::endianness::little); - support::endian::write(BOS, P.Constants.Num32BitValues, + support::endian::write(BOS, Descriptor.RegisterSpace, llvm::endianness::little); - break; - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - if (Version == 1) { - support::endian::write(BOS, P.Descriptor_V10.ShaderRegister, - llvm::endianness::little); - support::endian::write(BOS, P.Descriptor_V10.RegisterSpace, - llvm::endianness::little); - } else { - support::endian::write(BOS, P.Descriptor_V11.ShaderRegister, - llvm::endianness::little); - 
support::endian::write(BOS, P.Descriptor_V11.RegisterSpace, - llvm::endianness::little); - support::endian::write(BOS, P.Descriptor_V11.Flags, - llvm::endianness::little); - } + support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little); } } assert(Storage.size() == getSize()); diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp index be0e52fef04f5..8e40ff39ada36 100644 --- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp +++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp @@ -274,36 +274,38 @@ void DXContainerWriter::writeParts(raw_ostream &OS) { RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset; for (const auto &Param : P.RootSignature->Parameters) { - mcdxbc::RootParameter NewParam; - NewParam.Header = dxbc::RootParameterHeader{ - Param.Type, Param.Visibility, Param.Offset}; + auto Header = dxbc::RootParameterHeader{Param.Type, Param.Visibility, + Param.Offset}; switch (Param.Type) { case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - NewParam.Constants.Num32BitValues = Param.Constants.Num32BitValues; - NewParam.Constants.RegisterSpace = Param.Constants.RegisterSpace; - NewParam.Constants.ShaderRegister = Param.Constants.ShaderRegister; + dxbc::RootConstants Constants; + Constants.Num32BitValues = Param.Constants.Num32BitValues; + Constants.RegisterSpace = Param.Constants.RegisterSpace; + Constants.ShaderRegister = Param.Constants.ShaderRegister; + RS.Parameters.addParameter(Header, Constants); break; case llvm::to_underlying(dxbc::RootParameterType::SRV): case llvm::to_underlying(dxbc::RootParameterType::UAV): case llvm::to_underlying(dxbc::RootParameterType::CBV): if (RS.Version == 1) { - NewParam.Descriptor_V10.RegisterSpace = - Param.Descriptor.RegisterSpace; - NewParam.Descriptor_V10.ShaderRegister = - Param.Descriptor.ShaderRegister; + dxbc::RST0::v0::RootDescriptor Descriptor; + Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace; + Descriptor.ShaderRegister = 
Param.Descriptor.ShaderRegister; + RS.Parameters.addParameter(Header, Descriptor); } else { - NewParam.Descriptor_V11.RegisterSpace = - Param.Descriptor.RegisterSpace; - NewParam.Descriptor_V11.ShaderRegister = - Param.Descriptor.ShaderRegister; - NewParam.Descriptor_V11.Flags = Param.Descriptor.getEncodedFlags(); + dxbc::RST0::v1::RootDescriptor Descriptor; + Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace; + Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister; + Descriptor.Flags = Param.Descriptor.getEncodedFlags(); + RS.Parameters.addParameter(Header, Descriptor); } break; + default: + // Unrecognized parameter type: record only the header. -1 converts to SIZE_MAX (Location is size_t) and no payload is stored. NOTE(review): RootParameter::iterator::operator* has no default case, so it yields a default-constructed RootConstants for this entry; confirm that is intended. + RS.Parameters.addHeader(Header, -1); } - - RS.Parameters.push_back(NewParam); } RS.write(OS); diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index ef299c17baf76..a2141aa1364ad 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -30,6 +30,7 @@ #include <cstdint> #include <optional> #include <utility> +#include <variant> using namespace llvm; using namespace llvm::dxil; @@ -75,31 +76,32 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (RootConstantNode->getNumOperands() != 5) return reportError(Ctx, "Invalid format for RootConstants Element"); - mcdxbc::RootParameter NewParameter; - NewParameter.Header.ParameterType = + dxbc::RootParameterHeader Header; + Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) - NewParameter.Header.ShaderVisibility = *Val; + Header.ShaderVisibility = *Val; else return reportError(Ctx, "Invalid value for ShaderVisibility"); + dxbc::RootConstants Constants; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) - NewParameter.Constants.ShaderRegister = *Val; + Constants.ShaderRegister = *Val; else return 
reportError(Ctx, "Invalid value for ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) - NewParameter.Constants.RegisterSpace = *Val; + Constants.RegisterSpace = *Val; else return reportError(Ctx, "Invalid value for RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) - NewParameter.Constants.Num32BitValues = *Val; + Constants.Num32BitValues = *Val; else return reportError(Ctx, "Invalid value for Num32BitValues"); - RSD.Parameters.push_back(NewParameter); + RSD.Parameters.addParameter(Header, Constants); return false; } @@ -164,12 +166,11 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { return reportValueError(Ctx, "RootFlags", RSD.Flags); } - for (const mcdxbc::RootParameter &P : RSD.Parameters) { - if (!dxbc::isValidShaderVisibility(P.Header.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - P.Header.ShaderVisibility); + for (const mcdxbc::RootParameterHeader &Header : RSD.Parameters.Headers) { + if (!dxbc::isValidShaderVisibility(Header.ShaderVisibility)) + return reportValueError(Ctx, "ShaderVisibility", Header.ShaderVisibility); - assert(dxbc::isValidParameterType(P.Header.ParameterType) && + assert(dxbc::isValidParameterType(Header.ParameterType) && "Invalid value for ParameterType"); } @@ -289,20 +290,20 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, << "\n"; OS << indent(Space) << "NumParameters: " << RS.Parameters.size() << "\n"; Space++; - for (auto const &P : RS.Parameters) { - OS << indent(Space) << "- Parameter Type: " << P.Header.ParameterType + for (auto const &Header : RS.Parameters.Headers) { + OS << indent(Space) << "- Parameter Type: " << Header.ParameterType << "\n"; OS << indent(Space + 2) - << "Shader Visibility: " << P.Header.ShaderVisibility << "\n"; - switch (P.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): + << "Shader Visibility: " << 
Header.ShaderVisibility << "\n"; + mcdxbc::ParametersView P = RS.Parameters.get(Header); + if (std::holds_alternative<dxbc::RootConstants>(P)) { + auto Constants = std::get<dxbc::RootConstants>(P); + OS << indent(Space + 2) << "Register Space: " << Constants.RegisterSpace + << "\n"; OS << indent(Space + 2) - << "Register Space: " << P.Constants.RegisterSpace << "\n"; + << "Shader Register: " << Constants.ShaderRegister << "\n"; OS << indent(Space + 2) - << "Shader Register: " << P.Constants.ShaderRegister << "\n"; - OS << indent(Space + 2) - << "Num 32 Bit Values: " << P.Constants.Num32BitValues << "\n"; - break; + << "Num 32 Bit Values: " << Constants.Num32BitValues << "\n"; } } Space--; _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits