llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Finn Plummer (inbelic) <details> <summary>Changes</summary> - defines RootDescriptorFlags in-memory representation - defines parseRootDescriptorFlags to be DXC compatible. This is why we support multiple `|`-combined flags, even though validation will assert that only one flag is set... - add unit tests to demonstrate functionality Final part of and resolves https://github.com/llvm/llvm-project/issues/126577 --- Full diff: https://github.com/llvm/llvm-project/pull/140152.diff 4 Files Affected: - (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+3) - (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+60) - (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+17-3) - (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+25) ``````````diff diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index 436d217cec5b1..7b9168290d62a 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -93,6 +93,7 @@ class RootSignatureParser { std::optional<llvm::hlsl::rootsig::Register> Reg; std::optional<uint32_t> Space; std::optional<llvm::hlsl::rootsig::ShaderVisibility> Visibility; + std::optional<llvm::hlsl::rootsig::RootDescriptorFlags> Flags; }; std::optional<ParsedRootParamParams> parseRootParamParams(RootSignatureToken::Kind RegType); @@ -113,6 +114,8 @@ class RootSignatureParser { /// Parsing methods of various enums std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility(); + std::optional<llvm::hlsl::rootsig::RootDescriptorFlags> + parseRootDescriptorFlags(); std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> parseDescriptorRangeFlags(); diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index edb61f29f10d7..faf261cc9b7fe 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ 
b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -193,6 +193,7 @@ std::optional<RootParam> RootSignatureParser::parseRootParam() { ExpectedReg = TokenKind::uReg; break; } + Param.setDefaultFlags(); auto Params = parseRootParamParams(ExpectedReg); if (!Params.has_value()) @@ -214,6 +215,9 @@ std::optional<RootParam> RootSignatureParser::parseRootParam() { if (Params->Visibility.has_value()) Param.Visibility = Params->Visibility.value(); + if (Params->Flags.has_value()) + Param.Flags = Params->Flags.value(); + if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_hlsl_unexpected_end_of_params, /*param of=*/TokenKind::kw_RootConstants)) @@ -475,6 +479,23 @@ RootSignatureParser::parseRootParamParams(TokenKind RegType) { return std::nullopt; Params.Visibility = Visibility; } + + // `flags` `=` ROOT_DESCRIPTOR_FLAGS + if (tryConsumeExpectedToken(TokenKind::kw_flags)) { + if (Params.Flags.has_value()) { + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param) + << CurToken.TokKind; + return std::nullopt; + } + + if (consumeExpectedToken(TokenKind::pu_equal)) + return std::nullopt; + + auto Flags = parseRootDescriptorFlags(); + if (!Flags.has_value()) + return std::nullopt; + Params.Flags = Flags; + } } while (tryConsumeExpectedToken(TokenKind::pu_comma)); return Params; @@ -654,6 +675,45 @@ RootSignatureParser::parseShaderVisibility() { return std::nullopt; } +std::optional<llvm::hlsl::rootsig::RootDescriptorFlags> +RootSignatureParser::parseRootDescriptorFlags() { + assert(CurToken.TokKind == TokenKind::pu_equal && + "Expects to only be invoked starting at given keyword"); + + // Handle the edge-case of '0' to specify no flags set + if (tryConsumeExpectedToken(TokenKind::int_literal)) { + if (!verifyZeroFlag()) { + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag); + return std::nullopt; + } + return RootDescriptorFlags::None; + } + + TokenKind Expected[] = { +#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME, 
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def" + }; + + std::optional<RootDescriptorFlags> Flags; + + do { + if (tryConsumeExpectedToken(Expected)) { + switch (CurToken.TokKind) { +#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) \ + case TokenKind::en_##NAME: \ + Flags = \ + maybeOrFlag<RootDescriptorFlags>(Flags, RootDescriptorFlags::NAME); \ + break; +#include "clang/Lex/HLSLRootSignatureTokenKinds.def" + default: + llvm_unreachable("Switch for consumed enum token was not provided"); + } + } + } while (tryConsumeExpectedToken(TokenKind::pu_or)); + + return Flags; +} + std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> RootSignatureParser::parseDescriptorRangeFlags() { assert(CurToken.TokKind == TokenKind::pu_equal && diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp index 02bf38dcb110f..7ed286589f8fa 100644 --- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp @@ -347,8 +347,11 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) { TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) { const llvm::StringLiteral Source = R"cc( CBV(b0), - SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY), - UAV(visibility = SHADER_VISIBILITY_HULL, u34893247) + SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY, + flags = DATA_VOLATILE | DATA_STATIC | DATA_STATIC_WHILE_SET_AT_EXECUTE + ), + UAV(visibility = SHADER_VISIBILITY_HULL, u34893247), + CBV(b0, flags = 0), )cc"; TrivialModuleLoader ModLoader; @@ -364,7 +367,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) { ASSERT_FALSE(Parser.parse()); - ASSERT_EQ(Elements.size(), 3u); + ASSERT_EQ(Elements.size(), 4u); RootElement Elem = Elements[0]; ASSERT_TRUE(std::holds_alternative<RootParam>(Elem)); @@ -372,6 +375,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) { ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 0u); 
ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u); ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::All); + ASSERT_EQ(std::get<RootParam>(Elem).Flags, + RootDescriptorFlags::DataStaticWhileSetAtExecute); Elem = Elements[1]; ASSERT_TRUE(std::holds_alternative<RootParam>(Elem)); @@ -380,6 +385,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) { ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 42u); ASSERT_EQ(std::get<RootParam>(Elem).Space, 4u); ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::Geometry); + ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::ValidFlags); Elem = Elements[2]; ASSERT_TRUE(std::holds_alternative<RootParam>(Elem)); @@ -388,6 +394,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) { ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 34893247u); ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u); ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::Hull); + ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::DataVolatile); + + Elem = Elements[3]; + ASSERT_EQ(std::get<RootParam>(Elem).Reg.ViewType, RegisterType::BReg); + ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 0u); + ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u); + ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::All); + ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::None); ASSERT_TRUE(Consumer->isSatisfied()); } diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index 7aa55215abae3..98fa5f09429e3 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -46,6 +46,14 @@ enum class RootFlags : uint32_t { ValidFlags = 0x00000fff }; +enum class RootDescriptorFlags : unsigned { + None = 0, + DataVolatile = 0x2, + DataStaticWhileSetAtExecute = 0x4, + DataStatic = 0x8, + ValidFlags = 0xe, +}; + enum class DescriptorRangeFlags : 
unsigned { None = 0, DescriptorsVolatile = 0x1, @@ -91,6 +99,23 @@ struct RootParam { Register Reg; uint32_t Space = 0; ShaderVisibility Visibility = ShaderVisibility::All; + RootDescriptorFlags Flags; + + void setDefaultFlags() { + assert(Type != ParamType::Sampler && + "Sampler is not a valid type of ParamType"); + switch (Type) { + case ParamType::CBuffer: + case ParamType::SRV: + Flags = RootDescriptorFlags::DataStaticWhileSetAtExecute; + break; + case ParamType::UAV: + Flags = RootDescriptorFlags::DataVolatile; + break; + case ParamType::Sampler: + break; + } + } }; // Models the end of a descriptor table and stores its visibility `````````` </details> https://github.com/llvm/llvm-project/pull/140152 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits