llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

- defines the `RootFlags` in-memory enum
- defines `parseRootFlags` to parse the various flag enums into a single `uint32_t`
- adds corresponding unit tests
- improves the diagnostic message emitted when a non-zero integer value is provided for the flags

Resolves https://github.com/llvm/llvm-project/issues/126575

---
Full diff: https://github.com/llvm/llvm-project/pull/138055.diff

7 Files Affected:

- (modified) clang/include/clang/Basic/DiagnosticParseKinds.td (+1)
- (modified) clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def (+19)
- (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+1)
- (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+63-10)
- (modified) clang/unittests/Lex/LexHLSLRootSignatureTest.cpp (+14-1)
- (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+51-1)
- (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+19-2)

``````````diff
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td index 72e765bcb800d..75ed28f95cd32 100644 --- a/clang/include/clang/Basic/DiagnosticParseKinds.td +++ b/clang/include/clang/Basic/DiagnosticParseKinds.td @@ -1842,5 +1842,6 @@ def err_hlsl_unexpected_end_of_params def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">; def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">; def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">; +def err_hlsl_rootsig_non_zero_flag : Error<"non-zero integer literal specified for flag value">; } // end of Parser diagnostics diff --git a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def index ecb8cfc7afa16..eac6ebda84965 
100644 --- a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def +++ b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def @@ -27,6 +27,9 @@ #endif // Defines the various types of enum +#ifndef ROOT_FLAG_ENUM +#define ROOT_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT) +#endif #ifndef UNBOUNDED_ENUM #define UNBOUNDED_ENUM(NAME, LIT) ENUM(NAME, LIT) #endif @@ -73,6 +76,7 @@ PUNCTUATOR(minus, '-') // RootElement Keywords: KEYWORD(RootSignature) // used only for diagnostic messaging +KEYWORD(RootFlags) KEYWORD(DescriptorTable) KEYWORD(RootConstants) @@ -100,6 +104,20 @@ UNBOUNDED_ENUM(unbounded, "unbounded") // Descriptor Range Offset Enum: DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND") +// Root Flag Enums: +ROOT_FLAG_ENUM(AllowInputAssemblerInputLayout, "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT") +ROOT_FLAG_ENUM(DenyVertexShaderRootAccess, "DENY_VERTEX_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyHullShaderRootAccess, "DENY_HULL_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyDomainShaderRootAccess, "DENY_DOMAIN_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyGeometryShaderRootAccess, "DENY_GEOMETRY_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyPixelShaderRootAccess, "DENY_PIXEL_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyAmplificationShaderRootAccess, "DENY_AMPLIFICATION_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyMeshShaderRootAccess, "DENY_MESH_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(AllowStreamOutput, "ALLOW_STREAM_OUTPUT") +ROOT_FLAG_ENUM(LocalRootSignature, "LOCAL_ROOT_SIGNATURE") +ROOT_FLAG_ENUM(CBVSRVUAVHeapDirectlyIndexed, "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED") +ROOT_FLAG_ENUM(SamplerHeapDirectlyIndexed , "SAMPLER_HEAP_DIRECTLY_INDEXED") + // Root Descriptor Flag Enums: ROOT_DESCRIPTOR_FLAG_ENUM(DataVolatile, "DATA_VOLATILE") ROOT_DESCRIPTOR_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE") @@ -127,6 +145,7 @@ SHADER_VISIBILITY_ENUM(Mesh, "SHADER_VISIBILITY_MESH") #undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF #undef 
DESCRIPTOR_RANGE_FLAG_ENUM_ON #undef ROOT_DESCRIPTOR_FLAG_ENUM +#undef ROOT_FLAG_ENUM #undef DESCRIPTOR_RANGE_OFFSET_ENUM #undef UNBOUNDED_ENUM #undef ENUM diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index 2ac2083983741..915266f8a36ae 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -71,6 +71,7 @@ class RootSignatureParser { // expected, or, there is a lexing error /// Root Element parse methods: + std::optional<llvm::hlsl::rootsig::RootFlags> parseRootFlags(); std::optional<llvm::hlsl::rootsig::RootConstants> parseRootConstants(); std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable(); std::optional<llvm::hlsl::rootsig::DescriptorTableClause> diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index a5006b77a6e44..4780af0f94162 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements, bool RootSignatureParser::parse() { // Iterate as many RootElements as possible do { + if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) { + auto Flags = parseRootFlags(); + if (!Flags.has_value()) + return true; + Elements.push_back(*Flags); + } + if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) { auto Constants = parseRootConstants(); if (!Constants.has_value()) @@ -47,6 +54,61 @@ bool RootSignatureParser::parse() { /*param of=*/TokenKind::kw_RootSignature); } +template <typename FlagType> +static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) { + if (!Flags.has_value()) + return Flag; + + return static_cast<FlagType>(llvm::to_underlying(Flags.value()) | + llvm::to_underlying(Flag)); +} + +std::optional<RootFlags> RootSignatureParser::parseRootFlags() { + assert(CurToken.TokKind == 
TokenKind::kw_RootFlags && + "Expects to only be invoked starting at given keyword"); + + if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after, + CurToken.TokKind)) + return std::nullopt; + + std::optional<RootFlags> Flags = RootFlags::None; + + // Handle the edge-case of '0' to specify no flags set + if (tryConsumeExpectedToken(TokenKind::int_literal)) { + if (!verifyZeroFlag()) { + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag); + return std::nullopt; + } + } else { + // Otherwise, parse as many flags as possible + TokenKind Expected[] = { +#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME, +#include "clang/Lex/HLSLRootSignatureTokenKinds.def" + }; + + do { + if (tryConsumeExpectedToken(Expected)) { + switch (CurToken.TokKind) { +#define ROOT_FLAG_ENUM(NAME, LIT) \ + case TokenKind::en_##NAME: \ + Flags = maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \ + break; +#include "clang/Lex/HLSLRootSignatureTokenKinds.def" + default: + llvm_unreachable("Switch for consumed enum token was not provided"); + } + } + } while (tryConsumeExpectedToken(TokenKind::pu_or)); + } + + if (consumeExpectedToken(TokenKind::pu_r_paren, + diag::err_hlsl_unexpected_end_of_params, + /*param of=*/TokenKind::kw_RootFlags)) + return std::nullopt; + + return Flags; +} + std::optional<RootConstants> RootSignatureParser::parseRootConstants() { assert(CurToken.TokKind == TokenKind::kw_RootConstants && "Expects to only be invoked starting at given keyword"); @@ -467,15 +529,6 @@ RootSignatureParser::parseShaderVisibility() { return std::nullopt; } -template <typename FlagType> -static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) { - if (!Flags.has_value()) - return Flag; - - return static_cast<FlagType>(llvm::to_underlying(Flags.value()) | - llvm::to_underlying(Flag)); -} - std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> RootSignatureParser::parseDescriptorRangeFlags() { assert(CurToken.TokKind == 
TokenKind::pu_equal && @@ -484,7 +537,7 @@ RootSignatureParser::parseDescriptorRangeFlags() { // Handle the edge-case of '0' to specify no flags set if (tryConsumeExpectedToken(TokenKind::int_literal)) { if (!verifyZeroFlag()) { - getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'"; + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag); return std::nullopt; } return DescriptorRangeFlags::None; diff --git a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp index 89e9a3183ad03..21a1f1f08ae05 100644 --- a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp +++ b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp @@ -87,7 +87,7 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) { RootSignature - DescriptorTable RootConstants + RootFlags DescriptorTable RootConstants num32BitConstants @@ -98,6 +98,19 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) { unbounded DESCRIPTOR_RANGE_OFFSET_APPEND + allow_input_assembler_input_layout + deny_vertex_shader_root_access + deny_hull_shader_root_access + deny_domain_shader_root_access + deny_geometry_shader_root_access + deny_pixel_shader_root_access + deny_amplification_shader_root_access + deny_mesh_shader_root_access + allow_stream_output + local_root_signature + cbv_srv_uav_heap_directly_indexed + sampler_heap_directly_indexed + DATA_VOLATILE DATA_STATIC_WHILE_SET_AT_EXECUTE DATA_STATIC diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp index 150eb3e6e54ef..18e1e517dae8f 100644 --- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp @@ -294,6 +294,56 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) { ASSERT_TRUE(Consumer->isSatisfied()); } +TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) { + const llvm::StringLiteral Source = R"cc( + RootFlags(), + RootFlags(0), + RootFlags( + 
deny_domain_shader_root_access | + deny_pixel_shader_root_access | + local_root_signature | + cbv_srv_uav_heap_directly_indexed | + deny_amplification_shader_root_access | + deny_geometry_shader_root_access | + deny_hull_shader_root_access | + deny_mesh_shader_root_access | + allow_stream_output | + sampler_heap_directly_indexed | + allow_input_assembler_input_layout | + deny_vertex_shader_root_access + ) + )cc"; + + TrivialModuleLoader ModLoader; + auto PP = createPP(Source, ModLoader); + auto TokLoc = SourceLocation(); + + hlsl::RootSignatureLexer Lexer(Source, TokLoc); + SmallVector<RootElement> Elements; + hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); + + // Test no diagnostics produced + Consumer->setNoDiag(); + + ASSERT_FALSE(Parser.parse()); + + ASSERT_EQ(Elements.size(), 3u); + + RootElement Elem = Elements[0]; + ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem)); + ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None); + + Elem = Elements[1]; + ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem)); + ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None); + + Elem = Elements[2]; + ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem)); + ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::ValidFlags); + + ASSERT_TRUE(Consumer->isSatisfied()); +} + TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) { // This test will checks we can handling trailing commas ',' const llvm::StringLiteral Source = R"cc( @@ -496,7 +546,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) { hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced - Consumer->setExpected(diag::err_expected); + Consumer->setExpected(diag::err_hlsl_rootsig_non_zero_flag); ASSERT_TRUE(Parser.parse()); ASSERT_TRUE(Consumer->isSatisfied()); diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index 8b8324df18bb3..2ecaf69fc2f9c 100644 --- 
a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -23,6 +23,23 @@ namespace rootsig { // Definition of the various enumerations and flags +enum class RootFlags : uint32_t { + None = 0, + AllowInputAssemblerInputLayout = 0x1, + DenyVertexShaderRootAccess = 0x2, + DenyHullShaderRootAccess = 0x4, + DenyDomainShaderRootAccess = 0x8, + DenyGeometryShaderRootAccess = 0x10, + DenyPixelShaderRootAccess = 0x20, + AllowStreamOutput = 0x40, + LocalRootSignature = 0x80, + DenyAmplificationShaderRootAccess = 0x100, + DenyMeshShaderRootAccess = 0x200, + CBVSRVUAVHeapDirectlyIndexed = 0x400, + SamplerHeapDirectlyIndexed = 0x800, + ValidFlags = 0x00000fff +}; + enum class DescriptorRangeFlags : unsigned { None = 0, DescriptorsVolatile = 0x1, @@ -97,8 +114,8 @@ struct DescriptorTableClause { }; // Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause -using RootElement = - std::variant<RootConstants, DescriptorTable, DescriptorTableClause>; +using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable, + DescriptorTableClause>; } // namespace rootsig } // namespace hlsl `````````` </details> https://github.com/llvm/llvm-project/pull/138055 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits