llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

Implements serialization of the remaining completely defined `RootElement`s,
namely `RootDescriptor`s and `StaticSampler`s.

- Adds unit testing for the serialization methods

Resolves https://github.com/llvm/llvm-project/issues/138191
Resolves https://github.com/llvm/llvm-project/issues/138193

---
Full diff: https://github.com/llvm/llvm-project/pull/143198.diff


3 Files Affected:

- (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h (+6) 
- (modified) llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp (+254) 
- (modified) llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp (+121) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h
index ca20e6719f3a4..7489777670703 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h
@@ -32,6 +32,12 @@ LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const RootFlags &Flags);
 LLVM_ABI raw_ostream &operator<<(raw_ostream &OS,
                                  const RootConstants &Constants);
 
+LLVM_ABI raw_ostream &operator<<(raw_ostream &OS,
+                                 const RootDescriptor &Descriptor);
+
+LLVM_ABI raw_ostream &operator<<(raw_ostream &OS,
+                                 const StaticSampler &StaticSampler);
+
 LLVM_ABI raw_ostream &operator<<(raw_ostream &OS,
                                  const DescriptorTableClause &Clause);
 
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 24486a55ecf6a..70c3e72c1f806 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -71,6 +71,199 @@ static raw_ostream &operator<<(raw_ostream &OS,
   return OS;
 }
 
+static raw_ostream &operator<<(raw_ostream &OS, const SamplerFilter &Filter) {
+  switch (Filter) {
+  case SamplerFilter::MinMagMipPoint:
+    OS << "MinMagMipPoint";
+    break;
+  case SamplerFilter::MinMagPointMipLinear:
+    OS << "MinMagPointMipLinear";
+    break;
+  case SamplerFilter::MinPointMagLinearMipPoint:
+    OS << "MinPointMagLinearMipPoint";
+    break;
+  case SamplerFilter::MinPointMagMipLinear:
+    OS << "MinPointMagMipLinear";
+    break;
+  case SamplerFilter::MinLinearMagMipPoint:
+    OS << "MinLinearMagMipPoint";
+    break;
+  case SamplerFilter::MinLinearMagPointMipLinear:
+    OS << "MinLinearMagPointMipLinear";
+    break;
+  case SamplerFilter::MinMagLinearMipPoint:
+    OS << "MinMagLinearMipPoint";
+    break;
+  case SamplerFilter::MinMagMipLinear:
+    OS << "MinMagMipLinear";
+    break;
+  case SamplerFilter::Anisotropic:
+    OS << "Anisotropic";
+    break;
+  case SamplerFilter::ComparisonMinMagMipPoint:
+    OS << "ComparisonMinMagMipPoint";
+    break;
+  case SamplerFilter::ComparisonMinMagPointMipLinear:
+    OS << "ComparisonMinMagPointMipLinear";
+    break;
+  case SamplerFilter::ComparisonMinPointMagLinearMipPoint:
+    OS << "ComparisonMinPointMagLinearMipPoint";
+    break;
+  case SamplerFilter::ComparisonMinPointMagMipLinear:
+    OS << "ComparisonMinPointMagMipLinear";
+    break;
+  case SamplerFilter::ComparisonMinLinearMagMipPoint:
+    OS << "ComparisonMinLinearMagMipPoint";
+    break;
+  case SamplerFilter::ComparisonMinLinearMagPointMipLinear:
+    OS << "ComparisonMinLinearMagPointMipLinear";
+    break;
+  case SamplerFilter::ComparisonMinMagLinearMipPoint:
+    OS << "ComparisonMinMagLinearMipPoint";
+    break;
+  case SamplerFilter::ComparisonMinMagMipLinear:
+    OS << "ComparisonMinMagMipLinear";
+    break;
+  case SamplerFilter::ComparisonAnisotropic:
+    OS << "ComparisonAnisotropic";
+    break;
+  case SamplerFilter::MinimumMinMagMipPoint:
+    OS << "MinimumMinMagMipPoint";
+    break;
+  case SamplerFilter::MinimumMinMagPointMipLinear:
+    OS << "MinimumMinMagPointMipLinear";
+    break;
+  case SamplerFilter::MinimumMinPointMagLinearMipPoint:
+    OS << "MinimumMinPointMagLinearMipPoint";
+    break;
+  case SamplerFilter::MinimumMinPointMagMipLinear:
+    OS << "MinimumMinPointMagMipLinear";
+    break;
+  case SamplerFilter::MinimumMinLinearMagMipPoint:
+    OS << "MinimumMinLinearMagMipPoint";
+    break;
+  case SamplerFilter::MinimumMinLinearMagPointMipLinear:
+    OS << "MinimumMinLinearMagPointMipLinear";
+    break;
+  case SamplerFilter::MinimumMinMagLinearMipPoint:
+    OS << "MinimumMinMagLinearMipPoint";
+    break;
+  case SamplerFilter::MinimumMinMagMipLinear:
+    OS << "MinimumMinMagMipLinear";
+    break;
+  case SamplerFilter::MinimumAnisotropic:
+    OS << "MinimumAnisotropic";
+    break;
+  case SamplerFilter::MaximumMinMagMipPoint:
+    OS << "MaximumMinMagMipPoint";
+    break;
+  case SamplerFilter::MaximumMinMagPointMipLinear:
+    OS << "MaximumMinMagPointMipLinear";
+    break;
+  case SamplerFilter::MaximumMinPointMagLinearMipPoint:
+    OS << "MaximumMinPointMagLinearMipPoint";
+    break;
+  case SamplerFilter::MaximumMinPointMagMipLinear:
+    OS << "MaximumMinPointMagMipLinear";
+    break;
+  case SamplerFilter::MaximumMinLinearMagMipPoint:
+    OS << "MaximumMinLinearMagMipPoint";
+    break;
+  case SamplerFilter::MaximumMinLinearMagPointMipLinear:
+    OS << "MaximumMinLinearMagPointMipLinear";
+    break;
+  case SamplerFilter::MaximumMinMagLinearMipPoint:
+    OS << "MaximumMinMagLinearMipPoint";
+    break;
+  case SamplerFilter::MaximumMinMagMipLinear:
+    OS << "MaximumMinMagMipLinear";
+    break;
+  case SamplerFilter::MaximumAnisotropic:
+    OS << "MaximumAnisotropic";
+    break;
+  }
+
+  return OS;
+}
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const TextureAddressMode &Address) {
+  switch (Address) {
+  case TextureAddressMode::Wrap:
+    OS << "Wrap";
+    break;
+  case TextureAddressMode::Mirror:
+    OS << "Mirror";
+    break;
+  case TextureAddressMode::Clamp:
+    OS << "Clamp";
+    break;
+  case TextureAddressMode::Border:
+    OS << "Border";
+    break;
+  case TextureAddressMode::MirrorOnce:
+    OS << "MirrorOnce";
+    break;
+  }
+
+  return OS;
+}
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const ComparisonFunc &CompFunc) {
+  switch (CompFunc) {
+  case ComparisonFunc::Never:
+    OS << "Never";
+    break;
+  case ComparisonFunc::Less:
+    OS << "Less";
+    break;
+  case ComparisonFunc::Equal:
+    OS << "Equal";
+    break;
+  case ComparisonFunc::LessEqual:
+    OS << "LessEqual";
+    break;
+  case ComparisonFunc::Greater:
+    OS << "Greater";
+    break;
+  case ComparisonFunc::NotEqual:
+    OS << "NotEqual";
+    break;
+  case ComparisonFunc::GreaterEqual:
+    OS << "GreaterEqual";
+    break;
+  case ComparisonFunc::Always:
+    OS << "Always";
+    break;
+  }
+
+  return OS;
+}
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const StaticBorderColor &BorderColor) {
+  switch (BorderColor) {
+  case StaticBorderColor::TransparentBlack:
+    OS << "TransparentBlack";
+    break;
+  case StaticBorderColor::OpaqueBlack:
+    OS << "OpaqueBlack";
+    break;
+  case StaticBorderColor::OpaqueWhite:
+    OS << "OpaqueWhite";
+    break;
+  case StaticBorderColor::OpaqueBlackUint:
+    OS << "OpaqueBlackUint";
+    break;
+  case StaticBorderColor::OpaqueWhiteUint:
+    OS << "OpaqueWhiteUint";
+    break;
+  }
+
+  return OS;
+}
+
 static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
   switch (Type) {
   case ClauseType::CBuffer:
@@ -132,6 +325,42 @@ static raw_ostream &operator<<(raw_ostream &OS,
   return OS;
 }
 
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const RootDescriptorFlags &Flags) {
+  bool FlagSet = false;
+  unsigned Remaining = llvm::to_underlying(Flags);
+  while (Remaining) {
+    unsigned Bit = 1u << llvm::countr_zero(Remaining);
+    if (Remaining & Bit) {
+      if (FlagSet)
+        OS << " | ";
+
+      switch (static_cast<RootDescriptorFlags>(Bit)) {
+      case RootDescriptorFlags::DataVolatile:
+        OS << "DataVolatile";
+        break;
+      case RootDescriptorFlags::DataStaticWhileSetAtExecute:
+        OS << "DataStaticWhileSetAtExecute";
+        break;
+      case RootDescriptorFlags::DataStatic:
+        OS << "DataStatic";
+        break;
+      default:
+        OS << "invalid: " << Bit;
+        break;
+      }
+
+      FlagSet = true;
+    }
+    Remaining &= ~Bit;
+  }
+
+  if (!FlagSet)
+    OS << "None";
+
+  return OS;
+}
+
 raw_ostream &operator<<(raw_ostream &OS, const RootFlags &Flags) {
   OS << "RootFlags(";
   bool FlagSet = false;
@@ -205,6 +434,31 @@ raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) {
   return OS;
 }
 
+raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
+  ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type));
+  OS << "Root" << Type << "(" << Descriptor.Reg
+     << ", space = " << Descriptor.Space
+     << ", visibility = " << Descriptor.Visibility
+     << ", flags = " << Descriptor.Flags << ")";
+
+  return OS;
+}
+
+raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) {
+  OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter
+     << ", addressU = " << Sampler.AddressU
+     << ", addressV = " << Sampler.AddressV
+     << ", addressW = " << Sampler.AddressW
+     << ", mipLODBias = " << Sampler.MipLODBias
+     << ", maxAnisotropy = " << Sampler.MaxAnisotropy
+     << ", comparisonFunc = " << Sampler.CompFunc
+     << ", borderColor = " << Sampler.BorderColor
+     << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD
+     << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility
+     << ")";
+  return OS;
+}
+
 raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) {
   OS << "DescriptorTable(numClauses = " << Table.NumClauses
      << ", visibility = " << Table.Visibility << ")";
diff --git a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
index 1a0c8e2a16396..831c5dd585fab 100644
--- a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
+++ b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
@@ -177,4 +177,125 @@ TEST(HLSLRootSignatureTest, AllRootFlagsDump) {
   EXPECT_EQ(Out, Expected);
 }
 
+TEST(HLSLRootSignatureTest, RootCBVDump) {
+  RootDescriptor Descriptor;
+  Descriptor.Type = DescriptorType::CBuffer;
+  Descriptor.Reg = {RegisterType::BReg, 0};
+  Descriptor.setDefaultFlags();
+
+  std::string Out;
+  llvm::raw_string_ostream OS(Out);
+  OS << Descriptor;
+  OS.flush();
+
+  std::string Expected = "RootCBV(b0, space = 0, "
+                         "visibility = All, "
+                         "flags = DataStaticWhileSetAtExecute)";
+  EXPECT_EQ(Out, Expected);
+}
+
+TEST(HLSLRootSignatureTest, RootSRVDump) {
+  RootDescriptor Descriptor;
+  Descriptor.Type = DescriptorType::SRV;
+  Descriptor.Reg = {RegisterType::TReg, 0};
+  Descriptor.Space = 42;
+  Descriptor.Visibility = ShaderVisibility::Geometry;
+  Descriptor.Flags = RootDescriptorFlags::None;
+
+  std::string Out;
+  llvm::raw_string_ostream OS(Out);
+  OS << Descriptor;
+  OS.flush();
+
+  std::string Expected =
+      "RootSRV(t0, space = 42, visibility = Geometry, flags = None)";
+  EXPECT_EQ(Out, Expected);
+}
+
+TEST(HLSLRootSignatureTest, RootUAVDump) {
+  RootDescriptor Descriptor;
+  Descriptor.Type = DescriptorType::UAV;
+  Descriptor.Reg = {RegisterType::UReg, 92374};
+  Descriptor.Space = 932847;
+  Descriptor.Visibility = ShaderVisibility::Hull;
+  Descriptor.Flags = RootDescriptorFlags::ValidFlags;
+
+  std::string Out;
+  llvm::raw_string_ostream OS(Out);
+  OS << Descriptor;
+  OS.flush();
+
+  std::string Expected =
+      "RootUAV(u92374, space = 932847, visibility = Hull, flags = "
+      "DataVolatile | "
+      "DataStaticWhileSetAtExecute | "
+      "DataStatic)";
+  EXPECT_EQ(Out, Expected);
+}
+
+TEST(HLSLRootSignatureTest, DefaultStaticSamplerDump) {
+  StaticSampler Sampler;
+  Sampler.Reg = {RegisterType::SReg, 0};
+
+  std::string Out;
+  llvm::raw_string_ostream OS(Out);
+  OS << Sampler;
+  OS.flush();
+
+  std::string Expected = "StaticSampler(s0, "
+                         "filter = Anisotropic, "
+                         "addressU = Wrap, "
+                         "addressV = Wrap, "
+                         "addressW = Wrap, "
+                         "mipLODBias = 0.000000e+00, "
+                         "maxAnisotropy = 16, "
+                         "comparisonFunc = LessEqual, "
+                         "borderColor = OpaqueWhite, "
+                         "minLOD = 0.000000e+00, "
+                         "maxLOD = 3.402823e+38, "
+                         "space = 0, "
+                         "visibility = All"
+                         ")";
+  EXPECT_EQ(Out, Expected);
+}
+
+TEST(HLSLRootSignatureTest, DefinedStaticSamplerDump) {
+  StaticSampler Sampler;
+  Sampler.Reg = {RegisterType::SReg, 0};
+
+  Sampler.Filter = SamplerFilter::ComparisonMinMagLinearMipPoint;
+  Sampler.AddressU = TextureAddressMode::Mirror;
+  Sampler.AddressV = TextureAddressMode::Border;
+  Sampler.AddressW = TextureAddressMode::Clamp;
+  Sampler.MipLODBias = 4.8f;
+  Sampler.MaxAnisotropy = 32;
+  Sampler.CompFunc = ComparisonFunc::NotEqual;
+  Sampler.BorderColor = StaticBorderColor::OpaqueBlack;
+  Sampler.MinLOD = 1.0f;
+  Sampler.MaxLOD = 32.0f;
+  Sampler.Space = 7;
+  Sampler.Visibility = ShaderVisibility::Domain;
+
+  std::string Out;
+  llvm::raw_string_ostream OS(Out);
+  OS << Sampler;
+  OS.flush();
+
+  std::string Expected = "StaticSampler(s0, "
+                         "filter = ComparisonMinMagLinearMipPoint, "
+                         "addressU = Mirror, "
+                         "addressV = Border, "
+                         "addressW = Clamp, "
+                         "mipLODBias = 4.800000e+00, "
+                         "maxAnisotropy = 32, "
+                         "comparisonFunc = NotEqual, "
+                         "borderColor = OpaqueBlack, "
+                         "minLOD = 1.000000e+00, "
+                         "maxLOD = 3.200000e+01, "
+                         "space = 7, "
+                         "visibility = Domain"
+                         ")";
+  EXPECT_EQ(Out, Expected);
+}
+
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/143198
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to