https://github.com/joaosaffran created https://github.com/llvm/llvm-project/pull/124967
Adds support for extracting Root Signature root-constant elements from metadata and writing them to the DXContainer. - Adds an analysis that parses the RootSignature metadata definition - Adds validation for root constants - This PR is related to task: https://github.com/llvm/llvm-project/issues/121487 >From 039270b69929b163e381b0eedd17ba5e32f237aa Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Wed, 29 Jan 2025 00:01:43 +0000 Subject: [PATCH 1/3] adding root constants support to DXIL --- llvm/include/llvm/BinaryFormat/DXContainer.h | 7 +++ .../llvm/MC/DXContainerRootSignature.h | 3 +- llvm/lib/MC/DXContainerRootSignature.cpp | 4 ++ .../lib/Target/DirectX/DXContainerGlobals.cpp | 22 ++++++++ llvm/lib/Target/DirectX/DXILRootSignature.cpp | 53 +++++++++++++++++-- llvm/lib/Target/DirectX/DXILRootSignature.h | 38 +++++++++++++ .../ContainerData/RootSignature-Flags.ll | 3 +- 7 files changed, 124 insertions(+), 6 deletions(-) diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h index cb57309ae183fc..50c30e509fbae6 100644 --- a/llvm/include/llvm/BinaryFormat/DXContainer.h +++ b/llvm/include/llvm/BinaryFormat/DXContainer.h @@ -86,6 +86,13 @@ struct RootConstants { }; struct RootParameter { + RootParameter() = default; + RootParameter(RootConstants RootConstant, ShaderVisibilityFlag Visibility) { + ParameterType = RootParameterType::Constants32Bit; + Constants = RootConstant; + ShaderVisibility = Visibility; + } + RootParameterType ParameterType; union { RootConstants Constants; diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h index 63a5699a978b79..f0228c5c6f1bc4 100644 --- a/llvm/include/llvm/MC/DXContainerRootSignature.h +++ b/llvm/include/llvm/MC/DXContainerRootSignature.h @@ -8,7 +8,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/BinaryFormat/DXContainer.h" -#include <cstdint> #include <limits> namespace llvm { @@ -23,6 +22,8 @@ struct 
RootSignatureHeader { void swapBytes(); void write(raw_ostream &OS); + + void pushPart(dxbc::RootParameter Part); }; } // namespace mcdxbc } // namespace llvm diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp index 5893752a767beb..c0080194ca7f5d 100644 --- a/llvm/lib/MC/DXContainerRootSignature.cpp +++ b/llvm/lib/MC/DXContainerRootSignature.cpp @@ -31,3 +31,7 @@ void RootSignatureHeader::write(raw_ostream &OS) { OS.write(reinterpret_cast<const char *>(&Param), BindingSize); } } + +void RootSignatureHeader::pushPart(dxbc::RootParameter Param){ + Parameters.push_back(Param); +} diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index 36e7cedbdaee0c..47ac64fda7f6ef 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -72,6 +72,23 @@ class DXContainerGlobals : public llvm::ModulePass { } // namespace +static dxbc::RootParameter constructHeaderPart(const RootSignaturePart &Part) { + + dxbc::ShaderVisibilityFlag Visibility = static_cast<dxbc::ShaderVisibilityFlag>(Part.Visibility); + + switch(Part.Type){ + + case PartType::Constants:{ + + return dxbc::RootParameter(dxbc::RootConstants { + Part.Constants.ShaderRegistry, + Part.Constants.RegistrySpace, + Part.Constants.Number32BitValues + }, Visibility); + } break; + } +} + bool DXContainerGlobals::runOnModule(Module &M) { llvm::SmallVector<GlobalValue *> Globals; Globals.push_back(getFeatureFlags(M)); @@ -163,6 +180,11 @@ void DXContainerGlobals::addRootSignature(Module &M, RootSignatureHeader RSH; RSH.Flags = MRS->Flags; RSH.Version = MRS->Version; + + for(const auto &Part : MRS->Parts){ + RSH.pushPart(constructHeaderPart(Part)); + } + RSH.write(OS); Constant *Constant = diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index 5ee9eea68b9e60..dd0500edb3de7f 100644 --- 
a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -14,21 +14,33 @@ #include "DXILRootSignature.h" #include "DirectX.h" #include "llvm/ADT/StringSwitch.h" -#include "llvm/BinaryFormat/DXContainer.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" #include <cassert> +#include <cstdint> using namespace llvm; using namespace llvm::dxil; + +static bool isValidShaderVisibility(uint32_t V) { + return V < static_cast<uint32_t>(ShaderVisibility::MAX_VALUE); +} + + +static uint64_t extractInt(MDNode *Node, unsigned int I) { + assert(I > 0 && I < Node->getNumOperands() && "Invalid operand Index"); + return mdconst::extract<ConstantInt>(Node->getOperand(I))->getZExtValue(); +} + + static bool parseRootFlags(ModuleRootSignature *MRS, MDNode *RootFlagNode) { assert(RootFlagNode->getNumOperands() == 2 && "Invalid format for RootFlag Element"); - auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand(1)); - auto Value = Flag->getZExtValue(); + uint64_t Value = extractInt(RootFlagNode, 1); // Root Element validation, as specified: // https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#validations-during-dxil-generation @@ -38,6 +50,36 @@ static bool parseRootFlags(ModuleRootSignature *MRS, MDNode *RootFlagNode) { return false; } +static bool parseRootConstants(ModuleRootSignature *MRS, MDNode *RootFlagNode) { + assert(RootFlagNode->getNumOperands() == 5 && + "Invalid format for RootFlag Element"); + + uint32_t MaybeShaderVisibility = extractInt(RootFlagNode, 1); + assert(isValidShaderVisibility(MaybeShaderVisibility) && "Invalid shader visibility value"); + + ShaderVisibility Visibility = static_cast<ShaderVisibility>(MaybeShaderVisibility); + + uint32_t ShaderRegistry = extractInt(RootFlagNode, 2); + uint32_t RegisterSpace = extractInt(RootFlagNode, 3); + uint32_t Num32BitsValue = extractInt(RootFlagNode, 4); + + 
RootConstants Constant { + ShaderRegistry, + RegisterSpace, + Num32BitsValue + }; + + RootSignaturePart Part { + PartType::Constants, + {Constant}, + Visibility + }; + + MRS->pushPart(Part); + + return false; +} + static bool parseRootSignatureElement(ModuleRootSignature *MRS, MDNode *Element) { MDString *ElementText = cast<MDString>(Element->getOperand(0)); @@ -63,7 +105,10 @@ static bool parseRootSignatureElement(ModuleRootSignature *MRS, break; } - case RootSignatureElementKind::RootConstants: + case RootSignatureElementKind::RootConstants:{ + return parseRootConstants(MRS, Element); + break; + } case RootSignatureElementKind::RootDescriptor: case RootSignatureElementKind::DescriptorTable: case RootSignatureElementKind::StaticSampler: diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h index 3bbbaa12b07984..68b83aa9d8041e 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.h +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -12,9 +12,11 @@ /// //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallVector.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" +#include <cstdint> #include <optional> namespace llvm { @@ -29,13 +31,49 @@ enum class RootSignatureElementKind { StaticSampler = 5 }; +enum class PartType { + Constants = 0 +}; + +enum class ShaderVisibility : uint32_t { + SHADER_VISIBILITY_ALL = 0, + SHADER_VISIBILITY_VERTEX = 1, + SHADER_VISIBILITY_HULL = 2, + SHADER_VISIBILITY_DOMAIN = 3, + SHADER_VISIBILITY_GEOMETRY =4 , + SHADER_VISIBILITY_PIXEL = 5, + SHADER_VISIBILITY_AMPLIFICATION = 6, + SHADER_VISIBILITY_MESH = 7, + // not a flag + MAX_VALUE = 8 +}; + +struct RootConstants { + uint32_t ShaderRegistry; + uint32_t RegistrySpace; + uint32_t Number32BitValues; +}; + +struct RootSignaturePart { + PartType Type; + union { + RootConstants Constants; + }; + ShaderVisibility Visibility; +}; + struct 
ModuleRootSignature { uint32_t Version; uint32_t Flags; + SmallVector<RootSignaturePart> Parts; ModuleRootSignature() = default; bool parse(int32_t Version, NamedMDNode *Root); + + void pushPart(RootSignaturePart Part) { + Parts.push_back(Part); + } }; class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll index 20253efbb8e5c5..b2eb1c8e47488c 100644 --- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll @@ -16,8 +16,9 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } !dx.rootsignatures = !{!2} ; list of function/root signature pairs !2 = !{ ptr @main, !3 } ; function, root signature -!3 = !{ !4 } ; list of root signature elements +!3 = !{ !4, !5 } ; list of root signature elements !4 = !{ !"RootFlags", i32 1 } ; 1 = allow_input_assembler_input_layout +!5 = !{ !"RootConstants", i32 0, i32 1, i32 2, i32 3 } ; DXC: - Name: RTS0 >From 8f839fdabff2bc2cd0d38fa4e8a2d925cdf24d1e Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Wed, 29 Jan 2025 00:27:09 +0000 Subject: [PATCH 2/3] fixing tests --- llvm/include/llvm/BinaryFormat/DXContainer.h | 2 +- llvm/lib/MC/DXContainerRootSignature.cpp | 2 +- .../lib/Target/DirectX/DXContainerGlobals.cpp | 21 ++++----- llvm/lib/Target/DirectX/DXILRootSignature.cpp | 43 ++++++++----------- llvm/lib/Target/DirectX/DXILRootSignature.h | 20 ++++----- .../ContainerData/RootSignature-Flags.ll | 22 +++++++--- llvm/test/CodeGen/DirectX/llc-pipeline.ll | 1 + 7 files changed, 54 insertions(+), 57 deletions(-) diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h index 50c30e509fbae6..25bada1fc0ef0f 100644 --- a/llvm/include/llvm/BinaryFormat/DXContainer.h +++ 
b/llvm/include/llvm/BinaryFormat/DXContainer.h @@ -92,7 +92,7 @@ struct RootParameter { Constants = RootConstant; ShaderVisibility = Visibility; } - + RootParameterType ParameterType; union { RootConstants Constants; diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp index c0080194ca7f5d..60df2d3934faa4 100644 --- a/llvm/lib/MC/DXContainerRootSignature.cpp +++ b/llvm/lib/MC/DXContainerRootSignature.cpp @@ -32,6 +32,6 @@ void RootSignatureHeader::write(raw_ostream &OS) { } } -void RootSignatureHeader::pushPart(dxbc::RootParameter Param){ +void RootSignatureHeader::pushPart(dxbc::RootParameter Param) { Parameters.push_back(Param); } diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index 47ac64fda7f6ef..3481e8a2796537 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -73,18 +73,19 @@ class DXContainerGlobals : public llvm::ModulePass { } // namespace static dxbc::RootParameter constructHeaderPart(const RootSignaturePart &Part) { - - dxbc::ShaderVisibilityFlag Visibility = static_cast<dxbc::ShaderVisibilityFlag>(Part.Visibility); - switch(Part.Type){ + dxbc::ShaderVisibilityFlag Visibility = + static_cast<dxbc::ShaderVisibilityFlag>(Part.Visibility); - case PartType::Constants:{ + switch (Part.Type) { - return dxbc::RootParameter(dxbc::RootConstants { - Part.Constants.ShaderRegistry, - Part.Constants.RegistrySpace, - Part.Constants.Number32BitValues - }, Visibility); + case PartType::Constants: { + + return dxbc::RootParameter( + dxbc::RootConstants{Part.Constants.ShaderRegistry, + Part.Constants.RegistrySpace, + Part.Constants.Number32BitValues}, + Visibility); } break; } } @@ -181,7 +182,7 @@ void DXContainerGlobals::addRootSignature(Module &M, RSH.Flags = MRS->Flags; RSH.Version = MRS->Version; - for(const auto &Part : MRS->Parts){ + for (const auto &Part : MRS->Parts) { 
RSH.pushPart(constructHeaderPart(Part)); } diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index dd0500edb3de7f..8d4ec26d4ce806 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -24,18 +24,15 @@ using namespace llvm; using namespace llvm::dxil; - static bool isValidShaderVisibility(uint32_t V) { - return V < static_cast<uint32_t>(ShaderVisibility::MAX_VALUE); + return V < static_cast<uint32_t>(ShaderVisibility::MAX_VALUE); } - static uint64_t extractInt(MDNode *Node, unsigned int I) { assert(I > 0 && I < Node->getNumOperands() && "Invalid operand Index"); return mdconst::extract<ConstantInt>(Node->getOperand(I))->getZExtValue(); } - static bool parseRootFlags(ModuleRootSignature *MRS, MDNode *RootFlagNode) { assert(RootFlagNode->getNumOperands() == 2 && @@ -53,27 +50,21 @@ static bool parseRootFlags(ModuleRootSignature *MRS, MDNode *RootFlagNode) { static bool parseRootConstants(ModuleRootSignature *MRS, MDNode *RootFlagNode) { assert(RootFlagNode->getNumOperands() == 5 && "Invalid format for RootFlag Element"); - - uint32_t MaybeShaderVisibility = extractInt(RootFlagNode, 1); - assert(isValidShaderVisibility(MaybeShaderVisibility) && "Invalid shader visibility value"); - - ShaderVisibility Visibility = static_cast<ShaderVisibility>(MaybeShaderVisibility); - - uint32_t ShaderRegistry = extractInt(RootFlagNode, 2); - uint32_t RegisterSpace = extractInt(RootFlagNode, 3); - uint32_t Num32BitsValue = extractInt(RootFlagNode, 4); - - RootConstants Constant { - ShaderRegistry, - RegisterSpace, - Num32BitsValue - }; - - RootSignaturePart Part { - PartType::Constants, - {Constant}, - Visibility - }; + + uint32_t MaybeShaderVisibility = extractInt(RootFlagNode, 1); + assert(isValidShaderVisibility(MaybeShaderVisibility) && + "Invalid shader visibility value"); + + ShaderVisibility Visibility = + static_cast<ShaderVisibility>(MaybeShaderVisibility); + + 
uint32_t ShaderRegistry = extractInt(RootFlagNode, 2); + uint32_t RegisterSpace = extractInt(RootFlagNode, 3); + uint32_t Num32BitsValue = extractInt(RootFlagNode, 4); + + RootConstants Constant{ShaderRegistry, RegisterSpace, Num32BitsValue}; + + RootSignaturePart Part{PartType::Constants, {Constant}, Visibility}; MRS->pushPart(Part); @@ -105,7 +96,7 @@ static bool parseRootSignatureElement(ModuleRootSignature *MRS, break; } - case RootSignatureElementKind::RootConstants:{ + case RootSignatureElementKind::RootConstants: { return parseRootConstants(MRS, Element); break; } diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h index 68b83aa9d8041e..019a980eadff56 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.h +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -31,16 +31,14 @@ enum class RootSignatureElementKind { StaticSampler = 5 }; -enum class PartType { - Constants = 0 -}; +enum class PartType { Constants = 0 }; enum class ShaderVisibility : uint32_t { SHADER_VISIBILITY_ALL = 0, SHADER_VISIBILITY_VERTEX = 1, SHADER_VISIBILITY_HULL = 2, SHADER_VISIBILITY_DOMAIN = 3, - SHADER_VISIBILITY_GEOMETRY =4 , + SHADER_VISIBILITY_GEOMETRY = 4, SHADER_VISIBILITY_PIXEL = 5, SHADER_VISIBILITY_AMPLIFICATION = 6, SHADER_VISIBILITY_MESH = 7, @@ -55,11 +53,11 @@ struct RootConstants { }; struct RootSignaturePart { - PartType Type; - union { - RootConstants Constants; - }; - ShaderVisibility Visibility; + PartType Type; + union { + RootConstants Constants; + }; + ShaderVisibility Visibility; }; struct ModuleRootSignature { @@ -71,9 +69,7 @@ struct ModuleRootSignature { bool parse(int32_t Version, NamedMDNode *Root); - void pushPart(RootSignaturePart Part) { - Parts.push_back(Part); - } + void pushPart(RootSignaturePart Part) { Parts.push_back(Part); } }; class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll 
b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll index b2eb1c8e47488c..6ee70219b8bd30 100644 --- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll @@ -3,7 +3,7 @@ target triple = "dxil-unknown-shadermodel6.0-compute" -; CHECK: @dx.rts0 = private constant [12 x i8] c"{{.*}}", section "RTS0", align 4 +; CHECK: @dx.rts0 = private constant [40 x i8] c"{{.*}}", section "RTS0", align 4 define void @main() #0 { @@ -21,9 +21,17 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } !5 = !{ !"RootConstants", i32 0, i32 1, i32 2, i32 3 } -; DXC: - Name: RTS0 -; DXC-NEXT: Size: 12 -; DXC-NEXT: RootSignature: -; DXC-NEXT: Size: 8 -; DXC-NEXT: Version: 1 -; DXC-NEXT: AllowInputAssemblerInputLayout: true +; DXC: - Name: RTS0 +; DXC-NEXT: Size: 40 +; DXC-NEXT: RootSignature: +; DXC-NEXT: Size: 64 +; DXC-NEXT: Version: 1 +; DXC-NEXT: NumParameters: 1 +; DXC-NEXT: Parameters: +; DXC-NEXT: - Type: Constants32Bit +; DXC-NEXT: ShaderVisibility: All +; DXC-NEXT: Constants: +; DXC-NEXT: Num32BitValues: 3 +; DXC-NEXT: ShaderRegister: 1 +; DXC-NEXT: RegisterSpace: 2 +; DXC-NEXT: AllowInputAssemblerInputLayout: true diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll index b0715572494146..fc0a7833ea2f07 100644 --- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll +++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll @@ -33,6 +33,7 @@ ; CHECK-ASM-NEXT: Print Module IR ; CHECK-OBJ-NEXT: DXIL Embedder +; CHECK-OBJ-NEXT: DXIL Root Signature Analysis ; CHECK-OBJ-NEXT: DXContainer Global Emitter ; CHECK-OBJ-NEXT: FunctionPass Manager ; CHECK-OBJ-NEXT: Lazy Machine Block Frequency Analysis >From 7fdf7822c3ab03dda84198bcc1684dbc4a7ac4bd Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Wed, 29 Jan 2025 18:19:20 +0000 Subject: [PATCH 3/3] clean up --- llvm/lib/Target/DirectX/DXILRootSignature.cpp | 3 --- 
llvm/lib/Target/DirectX/DXILRootSignature.h | 1 - 2 files changed, 4 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index 8d4ec26d4ce806..8cf3895a0c36bf 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -17,9 +17,6 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include <cassert> -#include <cstdint> using namespace llvm; using namespace llvm::dxil; diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h index 019a980eadff56..681d7a7cd0ecdb 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.h +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -16,7 +16,6 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" -#include <cstdint> #include <optional> namespace llvm { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits