Author: Finn Plummer
Date: 2025-04-25T13:05:30-07:00
New Revision: fecf0742b16dc332c7a75b0a6696f08694943862
URL: https://github.com/llvm/llvm-project/commit/fecf0742b16dc332c7a75b0a6696f08694943862
DIFF: https://github.com/llvm/llvm-project/commit/fecf0742b16dc332c7a75b0a6696f08694943862.diff

LOG: [HLSL][RootSignature] Add parsing of DescriptorRangeFlags (#136775)

- Defines `parseDescriptorRangeFlags` to establish a pattern of how flags
will be parsed
- Add corresponding unit tests

Part four of implementing #126569

Added:


Modified:
    clang/include/clang/Parse/ParseHLSLRootSignature.h
    clang/lib/Parse/ParseHLSLRootSignature.cpp
    clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
    llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Removed:


################################################################################
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d639ca91c002f..d2e8f4dbcfc0c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -81,6 +81,7 @@ class RootSignatureParser {
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
     std::optional<uint32_t> Space;
+    std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
   };
   std::optional<ParsedClauseParams>
   parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
@@ -91,11 +92,19 @@ class RootSignatureParser {
 
   /// Parsing methods of various enums
   std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
+  std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+  parseDescriptorRangeFlags();
 
   /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
   /// 32-bit integer
   std::optional<uint32_t> handleUIntLiteral();
 
+  /// Flags may specify the value of '0' to denote that there should be no
+  /// flags set.
+  ///
+  /// Return true if the current int_literal token is '0', otherwise false
+  bool verifyZeroFlag();
+
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.consumeToken(); }
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8244e91c8f89a..3b9e96017c88d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
     ExpectedReg = TokenKind::sReg;
     break;
   }
+  Clause.setDefaultFlags();
 
   auto Params = parseDescriptorTableClauseParams(ExpectedReg);
   if (!Params.has_value())
@@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Space.has_value())
     Clause.Space = Params->Space.value();
 
+  if (Params->Flags.has_value())
+    Clause.Flags = Params->Flags.value();
+
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/ParamKind))
@@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
         return std::nullopt;
       Params.Space = Space;
     }
+
+    // `flags` `=` DESCRIPTOR_RANGE_FLAGS
+    if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+      if (Params.Flags.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return std::nullopt;
+
+      auto Flags = parseDescriptorRangeFlags();
+      if (!Flags.has_value())
+        return std::nullopt;
+      Params.Flags = Flags;
+    }
+
   } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
   return Params;
@@ -268,6 +290,66 @@ RootSignatureParser::parseShaderVisibility() {
   return std::nullopt;
 }
 
+/// Bitwise-OR Flag into Flags, or return Flag itself when no flags have been
+/// accumulated yet.
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+  if (!Flags.has_value())
+    return Flag;
+
+  return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
+                               llvm::to_underlying(Flag));
+}
+
+/// Parse DESCRIPTOR_RANGE_FLAGS: either the literal '0' (no flags) or one or
+/// more flag enums separated by '|'. Invoked with CurToken at the '='.
+std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+RootSignatureParser::parseDescriptorRangeFlags() {
+  assert(CurToken.TokKind == TokenKind::pu_equal &&
+         "Expects to only be invoked starting at given keyword");
+
+  // Handle the edge-case of '0' to specify no flags set
+  if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+    if (!verifyZeroFlag()) {
+      getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+      return std::nullopt;
+    }
+    return DescriptorRangeFlags::None;
+  }
+
+  TokenKind Expected[] = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  };
+
+  std::optional<DescriptorRangeFlags> Flags;
+
+  do {
+    // Every operand (the first one, and each one after a '|') must be a
+    // flag. Report anything else instead of returning std::nullopt without
+    // a diagnostic or accepting a stray leading/trailing '|'.
+    if (!tryConsumeExpectedToken(Expected)) {
+      consumeNextToken(); // point the diagnostic at the offending token
+      getDiags().Report(CurToken.TokLoc, diag::err_expected)
+          << "descriptor range flag";
+      return std::nullopt;
+    }
+
+    switch (CurToken.TokKind) {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON)                              \
+  case TokenKind::en_##NAME:                                                   \
+    Flags =                                                                    \
+        maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME);  \
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+    default:
+      llvm_unreachable("Switch for consumed enum token was not provided");
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+  return Flags;
+}
+
 std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   // Parse the numeric value and do semantic checks on its specification
   clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
@@ -290,6 +372,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   return Val.getExtValue();
 }
 
+bool RootSignatureParser::verifyZeroFlag() {
+  assert(CurToken.TokKind == TokenKind::int_literal);
+  auto X = handleUIntLiteral();
+  return X.has_value() && X.value() == 0;
+}
+
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
   return peekExpectedToken(ArrayRef{Expected});
 }
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 1d89567509e72..f4baf1580de61 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
       CBV(b0),
-      SRV(space = 3, t42),
+      SRV(space = 3, t42, flags = 0),
       visibility = SHADER_VISIBILITY_PIXEL,
       Sampler(s987, space = +2),
-      UAV(u4294967294)
+      UAV(u4294967294,
+        flags = Descriptors_Volatile | Data_Volatile
+                      | Data_Static_While_Set_At_Execute | Data_Static
+                      | Descriptors_Static_Keeping_Buffer_Bounds_Checks
+      )
     ),
     DescriptorTable()
   )cc";
@@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::BReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::DataStaticWhileSetAtExecute);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::TReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::SReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::UReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::ValidFlags);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
+  // This test checks that we can set the valid enum for Sampler descriptor
+  // range flags
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  RootElement Elem = Elements[0];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::ValidSamplerFlags);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
   // This test will checks we can handling trailing commas ','
   const llvm::StringLiteral Source = R"cc(
@@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
+  // This test will check that parsing fails when a non-zero integer literal
+  // is given to flags
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, flags = 3)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_expected);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 } // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index d51b853942dd3..0745bce983bb3 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,18 @@ namespace rootsig {
 
 // Definition of the various enumerations and flags
 
+// Descriptor range flag bits; ValidFlags is the OR of every individual flag
+enum class DescriptorRangeFlags : unsigned {
+  None = 0,
+  DescriptorsVolatile = 0x1,
+  DataVolatile = 0x2,
+  DataStaticWhileSetAtExecute = 0x4,
+  DataStatic = 0x8,
+  DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
+  ValidFlags = 0x1000f,
+  ValidSamplerFlags = DescriptorsVolatile,
+};
+
 enum class ShaderVisibility {
   All = 0,
   Vertex = 1,
@@ -55,6 +67,23 @@ struct DescriptorTableClause {
   ClauseType Type;
   Register Reg;
   uint32_t Space = 0;
+  DescriptorRangeFlags Flags = DescriptorRangeFlags::None;
+
+  /// Set Flags to the default value for this clause's Type.
+  void setDefaultFlags() {
+    switch (Type) {
+    case ClauseType::CBuffer:
+    case ClauseType::SRV:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+      break;
+    case ClauseType::UAV:
+      Flags = DescriptorRangeFlags::DataVolatile;
+      break;
+    case ClauseType::Sampler:
+      Flags = DescriptorRangeFlags::None;
+      break;
+    }
+  }
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits