https://github.com/joaosaffran created 
https://github.com/llvm/llvm-project/pull/137284

None

>From 7219ed4328aff2929f021c5efbd6901bc4bd2e20 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saff...@microsoft.com>
Date: Fri, 25 Apr 2025 05:09:08 +0000
Subject: [PATCH] refactoring mcdxbc struct to store root parameters out of
 order

---
 .../llvm/MC/DXContainerRootSignature.h        | 138 +++++++++++++++++-
 .../include/llvm/ObjectYAML/DXContainerYAML.h |   2 +-
 llvm/lib/MC/DXContainerRootSignature.cpp      |  78 +++++-----
 llvm/lib/ObjectYAML/DXContainerEmitter.cpp    |  36 ++---
 llvm/lib/Target/DirectX/DXILRootSignature.cpp |  45 +++---
 5 files changed, 208 insertions(+), 91 deletions(-)

diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h
index 1f421d726bf38..e1f4abbcebf8f 100644
--- a/llvm/include/llvm/MC/DXContainerRootSignature.h
+++ b/llvm/include/llvm/MC/DXContainerRootSignature.h
@@ -6,22 +6,146 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/BinaryFormat/DXContainer.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cstddef>
 #include <cstdint>
-#include <limits>
+#include <variant>
 
 namespace llvm {
 
 class raw_ostream;
 namespace mcdxbc {
 
+struct RootParameterHeader : public dxbc::RootParameterHeader {
+  /// Index of this parameter's payload inside the owning container's storage.
+  size_t Location;
+  /// NOTE(review): Location has no initializer; default-init leaves it unset.
+  RootParameterHeader() = default;
+  /// Copies the serialized header H and records L as the payload index.
+  RootParameterHeader(dxbc::RootParameterHeader H, size_t L)
+      : dxbc::RootParameterHeader(H), Location(L) {}
+};
+// RootDescriptor: v0 or v1 descriptor. ParametersView: those or RootConstants.
+using RootDescriptor = std::variant<dxbc::RST0::v0::RootDescriptor,
+                                    dxbc::RST0::v1::RootDescriptor>;
+using ParametersView =
+    std::variant<dxbc::RootConstants, dxbc::RST0::v0::RootDescriptor,
+                 dxbc::RST0::v1::RootDescriptor>;
 struct RootParameter {
-  dxbc::RootParameterHeader Header;
-  union {
-    dxbc::RootConstants Constants;
-    dxbc::RST0::v0::RootDescriptor Descriptor_V10;
-    dxbc::RST0::v1::RootDescriptor Descriptor_V11;
+  SmallVector<RootParameterHeader> Headers;
+  // Payload storage; each header's Location indexes one of these two lists.
+  SmallVector<dxbc::RootConstants> Constants;
+  SmallVector<RootDescriptor> Descriptors;
+  /// Appends a header for H whose payload index is L.
+  void addHeader(dxbc::RootParameterHeader H, size_t L) {
+    Headers.push_back(RootParameterHeader(H, L));
+  }
+  /// Stores root constants C plus a header pointing at their slot.
+  void addParameter(dxbc::RootParameterHeader H, dxbc::RootConstants C) {
+    addHeader(H, Constants.size());
+    Constants.push_back(C);
+  }
+  /// Stores a v0 root descriptor D plus a header pointing at its slot.
+  void addParameter(dxbc::RootParameterHeader H,
+                    dxbc::RST0::v0::RootDescriptor D) {
+    addHeader(H, Descriptors.size());
+    Descriptors.push_back(D);
+  }
+  /// Stores a v1 root descriptor D plus a header pointing at its slot.
+  void addParameter(dxbc::RootParameterHeader H,
+                    dxbc::RST0::v1::RootDescriptor D) {
+    addHeader(H, Descriptors.size());
+    Descriptors.push_back(D);
+  }
+  /// Returns a copy of the payload H refers to, selected by H.ParameterType.
+  ParametersView get(const RootParameterHeader &H) const {
+    switch (H.ParameterType) {
+    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
+      return Constants[H.Location];
+    case llvm::to_underlying(dxbc::RootParameterType::CBV):
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+      RootDescriptor VersionedParam = Descriptors[H.Location];
+      if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(
+              VersionedParam))
+        return std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam);
+      return std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam);
+    }
+    // Reached only when ParameterType matched none of the cases above.
+    llvm_unreachable("Unimplemented parameter type");
+  }
+  /// Walks Headers in order, yielding each header's payload by value.
+  struct iterator {
+    const RootParameter &Parameters;
+    SmallVector<RootParameterHeader>::const_iterator Current;
+
+    // P is the owning container; begin()/end() pass positions in P.Headers.
+    iterator(const RootParameter &P,
+             SmallVector<RootParameterHeader>::const_iterator C)
+        : Parameters(P), Current(C) {}
+    iterator(const iterator &) = default;
+    /// Copies out the current payload, or a default ParametersView if unmatched.
+    ParametersView operator*() {
+      ParametersView Val;
+      switch (Current->ParameterType) {
+      case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
+        Val = Parameters.Constants[Current->Location];
+        break;
+
+      case llvm::to_underlying(dxbc::RootParameterType::CBV):
+      case llvm::to_underlying(dxbc::RootParameterType::SRV):
+      case llvm::to_underlying(dxbc::RootParameterType::UAV):
+        RootDescriptor VersionedParam =
+            Parameters.Descriptors[Current->Location];
+        if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(
+                VersionedParam))
+          Val = std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam);
+        else
+          Val = std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam);
+        break;
+      }
+      return Val;
+    }
+    // NOTE(review): pre-increment/decrement return by value, not by reference.
+    iterator operator++() {
+      Current++;
+      return *this;
+    }
+
+    iterator operator++(int) {
+      iterator Tmp = *this;
+      ++*this;
+      return Tmp;
+    }
+
+    iterator operator--() {
+      Current--;
+      return *this;
+    }
+
+    iterator operator--(int) {
+      iterator Tmp = *this;
+      --*this;
+      return Tmp;
+    }
+    // NOTE(review): comparisons look only at Current and are not const-qualified.
+    bool operator==(const iterator I) { return I.Current == Current; }
+    bool operator!=(const iterator I) { return !(*this == I); }
   };
+  /// Iterator positioned at the first header.
+  iterator begin() const { return iterator(*this, Headers.begin()); }
+  /// Iterator positioned one past the last header.
+  iterator end() const { return iterator(*this, Headers.end()); }
+  /// Number of stored headers.
+  size_t size() const { return Headers.size(); }
+  /// True when no headers are stored.
+  bool isEmpty() const { return Headers.empty(); }
+  /// Range spanning begin() to end().
+  llvm::iterator_range<RootParameter::iterator> getAll() const {
+    return llvm::make_range(begin(), end());
+  }
 };
 struct RootSignatureDesc {
 
@@ -30,7 +154,7 @@ struct RootSignatureDesc {
   uint32_t RootParameterOffset = 0U;
   uint32_t StaticSamplersOffset = 0u;
   uint32_t NumStaticSamplers = 0u;
-  SmallVector<mcdxbc::RootParameter> Parameters;
+  mcdxbc::RootParameter Parameters;
 
   void write(raw_ostream &OS) const;
 
diff --git a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
index c54c995acd263..e86a869da99bc 100644
--- a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
+++ b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
@@ -95,7 +95,7 @@ struct RootParameterYamlDesc {
   uint32_t Type;
   uint32_t Visibility;
   uint32_t Offset;
-  RootParameterYamlDesc() {};
+  RootParameterYamlDesc(){};
   RootParameterYamlDesc(uint32_t T) : Type(T) {
     switch (T) {
 
diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index a5210f4768f16..18242ccc1e935 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -8,7 +8,9 @@
 
 #include "llvm/MC/DXContainerRootSignature.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/EndianStream.h"
+#include <variant>
 
 using namespace llvm;
 using namespace llvm::mcdxbc;
@@ -32,22 +34,15 @@ size_t RootSignatureDesc::getSize() const {
   size_t Size = sizeof(dxbc::RootSignatureHeader) +
                 Parameters.size() * sizeof(dxbc::RootParameterHeader);
 
-  for (const mcdxbc::RootParameter &P : Parameters) {
-    switch (P.Header.ParameterType) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
-      Size += sizeof(dxbc::RootConstants);
-      break;
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
-      if (Version == 1)
-        Size += sizeof(dxbc::RST0::v0::RootDescriptor);
-      else
-        Size += sizeof(dxbc::RST0::v1::RootDescriptor);
-
-      break;
-    }
+  for (const auto &P : Parameters) {
+    std::visit(
+        [&Size](auto &Value) -> void {
+          using T = std::decay_t<decltype(Value)>;
+          Size += sizeof(T);
+        },
+        P);
   }
+
   return Size;
 }
 
@@ -66,45 +61,40 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
   support::endian::write(BOS, Flags, llvm::endianness::little);
 
   SmallVector<uint32_t> ParamsOffsets;
-  for (const mcdxbc::RootParameter &P : Parameters) {
-    support::endian::write(BOS, P.Header.ParameterType,
-                           llvm::endianness::little);
-    support::endian::write(BOS, P.Header.ShaderVisibility,
-                           llvm::endianness::little);
+  for (const auto &P : Parameters.Headers) {
+    support::endian::write(BOS, P.ParameterType, llvm::endianness::little);
+    support::endian::write(BOS, P.ShaderVisibility, llvm::endianness::little);
 
     ParamsOffsets.push_back(writePlaceholder(BOS));
   }
 
   assert(NumParameters == ParamsOffsets.size());
-  for (size_t I = 0; I < NumParameters; ++I) {
+  auto P = Parameters.begin();
+  for (size_t I = 0; I < NumParameters; ++I, P++) {
     rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
-    const mcdxbc::RootParameter &P = Parameters[I];
 
-    switch (P.Header.ParameterType) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
-      support::endian::write(BOS, P.Constants.ShaderRegister,
+    if (std::holds_alternative<dxbc::RootConstants>(*P)) {
+      auto Constants = std::get<dxbc::RootConstants>(*P);
+      support::endian::write(BOS, Constants.ShaderRegister,
+                             llvm::endianness::little);
+      support::endian::write(BOS, Constants.RegisterSpace,
                              llvm::endianness::little);
-      support::endian::write(BOS, P.Constants.RegisterSpace,
+      support::endian::write(BOS, Constants.Num32BitValues,
+                             llvm::endianness::little);
+    } else if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(*P)) {
+      auto Descriptor = std::get<dxbc::RST0::v0::RootDescriptor>(*P);
+      support::endian::write(BOS, Descriptor.ShaderRegister,
+                             llvm::endianness::little);
+      support::endian::write(BOS, Descriptor.RegisterSpace,
+                             llvm::endianness::little);
+    } else if (std::holds_alternative<dxbc::RST0::v1::RootDescriptor>(*P)) {
+      auto Descriptor = std::get<dxbc::RST0::v1::RootDescriptor>(*P);
+
+      support::endian::write(BOS, Descriptor.ShaderRegister,
                              llvm::endianness::little);
-      support::endian::write(BOS, P.Constants.Num32BitValues,
+      support::endian::write(BOS, Descriptor.RegisterSpace,
                              llvm::endianness::little);
-      break;
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
-      if (Version == 1) {
-        support::endian::write(BOS, P.Descriptor_V10.ShaderRegister,
-                               llvm::endianness::little);
-        support::endian::write(BOS, P.Descriptor_V10.RegisterSpace,
-                               llvm::endianness::little);
-      } else {
-        support::endian::write(BOS, P.Descriptor_V11.ShaderRegister,
-                               llvm::endianness::little);
-        support::endian::write(BOS, P.Descriptor_V11.RegisterSpace,
-                               llvm::endianness::little);
-        support::endian::write(BOS, P.Descriptor_V11.Flags,
-                               llvm::endianness::little);
-      }
+      support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
     }
   }
   assert(Storage.size() == getSize());
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index be0e52fef04f5..8e40ff39ada36 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -274,36 +274,38 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
       RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
 
       for (const auto &Param : P.RootSignature->Parameters) {
-        mcdxbc::RootParameter NewParam;
-        NewParam.Header = dxbc::RootParameterHeader{
-            Param.Type, Param.Visibility, Param.Offset};
+        auto Header = dxbc::RootParameterHeader{Param.Type, Param.Visibility,
+                                                Param.Offset};
 
         switch (Param.Type) {
         case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
-          NewParam.Constants.Num32BitValues = Param.Constants.Num32BitValues;
-          NewParam.Constants.RegisterSpace = Param.Constants.RegisterSpace;
-          NewParam.Constants.ShaderRegister = Param.Constants.ShaderRegister;
+          dxbc::RootConstants Constants;
+          Constants.Num32BitValues = Param.Constants.Num32BitValues;
+          Constants.RegisterSpace = Param.Constants.RegisterSpace;
+          Constants.ShaderRegister = Param.Constants.ShaderRegister;
+          RS.Parameters.addParameter(Header, Constants);
           break;
         case llvm::to_underlying(dxbc::RootParameterType::SRV):
         case llvm::to_underlying(dxbc::RootParameterType::UAV):
         case llvm::to_underlying(dxbc::RootParameterType::CBV):
           if (RS.Version == 1) {
-            NewParam.Descriptor_V10.RegisterSpace =
-                Param.Descriptor.RegisterSpace;
-            NewParam.Descriptor_V10.ShaderRegister =
-                Param.Descriptor.ShaderRegister;
+            dxbc::RST0::v0::RootDescriptor Descriptor;
+            Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace;
+            Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister;
+            RS.Parameters.addParameter(Header, Descriptor);
           } else {
-            NewParam.Descriptor_V11.RegisterSpace =
-                Param.Descriptor.RegisterSpace;
-            NewParam.Descriptor_V11.ShaderRegister =
-                Param.Descriptor.ShaderRegister;
-            NewParam.Descriptor_V11.Flags = Param.Descriptor.getEncodedFlags();
+            dxbc::RST0::v1::RootDescriptor Descriptor;
+            Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace;
+            Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister;
+            Descriptor.Flags = Param.Descriptor.getEncodedFlags();
+            RS.Parameters.addParameter(Header, Descriptor);
           }
 
           break;
+        default:
+          // Handling invalid parameter type edge case
+          RS.Parameters.addHeader(Header, -1);
         }
-
-        RS.Parameters.push_back(NewParam);
       }
 
       RS.write(OS);
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index ef299c17baf76..a2141aa1364ad 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -30,6 +30,7 @@
 #include <cstdint>
 #include <optional>
 #include <utility>
+#include <variant>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -75,31 +76,32 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
   if (RootConstantNode->getNumOperands() != 5)
     return reportError(Ctx, "Invalid format for RootConstants Element");
 
-  mcdxbc::RootParameter NewParameter;
-  NewParameter.Header.ParameterType =
+  dxbc::RootParameterHeader Header;
+  Header.ParameterType =
       llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
-    NewParameter.Header.ShaderVisibility = *Val;
+    Header.ShaderVisibility = *Val;
   else
     return reportError(Ctx, "Invalid value for ShaderVisibility");
 
+  dxbc::RootConstants Constants;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
-    NewParameter.Constants.ShaderRegister = *Val;
+    Constants.ShaderRegister = *Val;
   else
     return reportError(Ctx, "Invalid value for ShaderRegister");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
-    NewParameter.Constants.RegisterSpace = *Val;
+    Constants.RegisterSpace = *Val;
   else
     return reportError(Ctx, "Invalid value for RegisterSpace");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
-    NewParameter.Constants.Num32BitValues = *Val;
+    Constants.Num32BitValues = *Val;
   else
     return reportError(Ctx, "Invalid value for Num32BitValues");
 
-  RSD.Parameters.push_back(NewParameter);
+  RSD.Parameters.addParameter(Header, Constants);
 
   return false;
 }
@@ -164,12 +166,11 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
     return reportValueError(Ctx, "RootFlags", RSD.Flags);
   }
 
-  for (const mcdxbc::RootParameter &P : RSD.Parameters) {
-    if (!dxbc::isValidShaderVisibility(P.Header.ShaderVisibility))
-      return reportValueError(Ctx, "ShaderVisibility",
-                              P.Header.ShaderVisibility);
+  for (const mcdxbc::RootParameterHeader &Header : RSD.Parameters.Headers) {
+    if (!dxbc::isValidShaderVisibility(Header.ShaderVisibility))
+      return reportValueError(Ctx, "ShaderVisibility", Header.ShaderVisibility);
 
-    assert(dxbc::isValidParameterType(P.Header.ParameterType) &&
+    assert(dxbc::isValidParameterType(Header.ParameterType) &&
            "Invalid value for ParameterType");
   }
 
@@ -289,20 +290,20 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
        << "\n";
     OS << indent(Space) << "NumParameters: " << RS.Parameters.size() << "\n";
     Space++;
-    for (auto const &P : RS.Parameters) {
-      OS << indent(Space) << "- Parameter Type: " << P.Header.ParameterType
+    for (auto const &Header : RS.Parameters.Headers) {
+      OS << indent(Space) << "- Parameter Type: " << Header.ParameterType
          << "\n";
       OS << indent(Space + 2)
-         << "Shader Visibility: " << P.Header.ShaderVisibility << "\n";
-      switch (P.Header.ParameterType) {
-      case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
+         << "Shader Visibility: " << Header.ShaderVisibility << "\n";
+      mcdxbc::ParametersView P = RS.Parameters.get(Header);
+      if (std::holds_alternative<dxbc::RootConstants>(P)) {
+        auto Constants = std::get<dxbc::RootConstants>(P);
+        OS << indent(Space + 2) << "Register Space: " << Constants.RegisterSpace
+           << "\n";
         OS << indent(Space + 2)
-           << "Register Space: " << P.Constants.RegisterSpace << "\n";
+           << "Shader Register: " << Constants.ShaderRegister << "\n";
         OS << indent(Space + 2)
-           << "Shader Register: " << P.Constants.ShaderRegister << "\n";
-        OS << indent(Space + 2)
-           << "Num 32 Bit Values: " << P.Constants.Num32BitValues << "\n";
-        break;
+           << "Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
       }
     }
     Space--;

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to