https://github.com/Lancern created https://github.com/llvm/llvm-project/pull/173724
This patch makes `HeaderIncludes` also insert a global module fragment declaration (`module;`) when inserting a header, provided that all of the following conditions are met: - The source file is a module unit; - No tokens other than comments and whitespace precede the module declaration. The patch detects these conditions by checking whether the first declaration in the source file is a module declaration. From 36d602dbb6657254f3e07030976b30fa6d12a7f4 Mon Sep 17 00:00:00 2001 From: Sirui Mu <[email protected]> Date: Sat, 27 Dec 2025 22:42:50 +0800 Subject: [PATCH] [clang][Tooling] Insert global module fragment during header insertion This patch makes HeaderIncludes to also insert a `module;` declaration when inserting a header when all of the following conditions are met: - The source file is a module unit; - No tokens excluding comments and whitespaces exist before the module declaration. This patch detects the conditions by checking whether the first declaration in the source file is a module declaration. --- .../clang/Tooling/Inclusions/HeaderIncludes.h | 3 + .../lib/Tooling/Inclusions/HeaderIncludes.cpp | 76 ++++++++++++++++--- .../unittests/Tooling/HeaderIncludesTest.cpp | 60 +++++++++++++++ 3 files changed, 129 insertions(+), 10 deletions(-) diff --git a/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h b/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h index d5439dd2c84eb..72407e2b12062 100644 --- a/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h +++ b/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h @@ -130,6 +130,9 @@ class HeaderIncludes { unsigned MaxInsertOffset; // True if we find the main-file header in the Code. bool MainIncludeFound; + // True if header insertion should also insert a C++20 global module fragment + // declaration (i.e. a 'module;' declaration). + bool ShouldInsertGlobalModuleFragmentDecl; IncludeCategoryManager Categories; // Record the offset of the end of the last include in each category.
std::unordered_map<int, int> CategoryEndOffsets; diff --git a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp index e11319e99ba6a..1212af52e1490 100644 --- a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp +++ b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp @@ -32,6 +32,22 @@ LangOptions createLangOpts() { return LangOpts; } +// Create a new lexer on the given \p Code and calls \p Callback with the +// created source manager and lexer. \p Callback must be a callable object that +// could be invoked with (const SourceManager &, Lexer &). This function returns +// whatever \p Callback returns. +template <typename F> +auto withLexer(StringRef FileName, StringRef Code, const IncludeStyle &Style, + F &&Callback) + -> std::invoke_result_t<F, const SourceManager &, Lexer &> { + SourceManagerForFile VirtualSM(FileName, Code); + SourceManager &SM = VirtualSM.get(); + LangOptions LangOpts = createLangOpts(); + Lexer Lex(SM.getMainFileID(), SM.getBufferOrFake(SM.getMainFileID()), SM, + LangOpts); + return std::invoke(std::forward<F>(Callback), std::as_const(SM), Lex); +} + // Returns the offset after skipping a sequence of tokens, matched by \p // GetOffsetAfterSequence, from the start of the code. // \p GetOffsetAfterSequence should be a function that matches a sequence of @@ -40,15 +56,13 @@ unsigned getOffsetAfterTokenSequence( StringRef FileName, StringRef Code, const IncludeStyle &Style, llvm::function_ref<unsigned(const SourceManager &, Lexer &, Token &)> GetOffsetAfterSequence) { - SourceManagerForFile VirtualSM(FileName, Code); - SourceManager &SM = VirtualSM.get(); - LangOptions LangOpts = createLangOpts(); - Lexer Lex(SM.getMainFileID(), SM.getBufferOrFake(SM.getMainFileID()), SM, - LangOpts); - Token Tok; - // Get the first token. 
- Lex.LexFromRawLexer(Tok); - return GetOffsetAfterSequence(SM, Lex, Tok); + return withLexer(FileName, Code, Style, + [&](const SourceManager &SM, Lexer &Lex) { + Token Tok; + // Get the first token. + Lex.LexFromRawLexer(Tok); + return GetOffsetAfterSequence(SM, Lex, Tok); + }); } // Check if a sequence of tokens is like "#<Name> <raw_identifier>". If it is, @@ -190,6 +204,43 @@ unsigned getMaxHeaderInsertionOffset(StringRef FileName, StringRef Code, }); } +// Check whether the first declaration in the code is a C++20 module +// declaration, and it is not preceded by any preprocessor directives. +bool isFirstDeclModuleDecl(StringRef FileName, StringRef Code, + const IncludeStyle &Style) { + return withLexer( + FileName, Code, Style, [](const SourceManager &SM, Lexer &Lex) { + // Let the lexer skip any comments and whitespaces for us. + Lex.SetKeepWhitespaceMode(false); + Lex.SetCommentRetentionState(false); + + Token tok; + if (Lex.LexFromRawLexer(tok)) + return false; + + // A module declaration is made up of the following token sequence: + // export? module <ident> ('.' <ident>)* <partition> <attr> ; + // + // For convenience, we don't actually lex the whole declaration -- it's + // enough to distinguish a module declaration to just ensure an <ident> + // is following the "module" keyword. + + // Lex the optional "export" keyword. + if (tok.is(tok::raw_identifier) && tok.getRawIdentifier() == "export") { + if (Lex.LexFromRawLexer(tok)) + return false; + } + + // Lex the "module" keyword. + if (!tok.is(tok::raw_identifier) || + tok.getRawIdentifier() != "module" || Lex.LexFromRawLexer(tok)) + return false; + + // Make sure an identifier follows the "module" keyword. 
+ return tok.is(tok::raw_identifier); + }); +} + inline StringRef trimInclude(StringRef IncludeName) { return IncludeName.trim("\"<>"); } @@ -306,7 +357,10 @@ HeaderIncludes::HeaderIncludes(StringRef FileName, StringRef Code, MaxInsertOffset(MinInsertOffset + getMaxHeaderInsertionOffset( FileName, Code.drop_front(MinInsertOffset), Style)), - MainIncludeFound(false), Categories(Style, FileName) { + MainIncludeFound(false), + ShouldInsertGlobalModuleFragmentDecl( + isFirstDeclModuleDecl(FileName, Code, Style)), + Categories(Style, FileName) { // Add 0 for main header and INT_MAX for headers that are not in any // category. Priorities = {0, INT_MAX}; @@ -414,6 +468,8 @@ HeaderIncludes::insert(llvm::StringRef IncludeName, bool IsAngled, // newline should be added. if (InsertOffset == Code.size() && (!Code.empty() && Code.back() != '\n')) NewInclude = "\n" + NewInclude; + if (ShouldInsertGlobalModuleFragmentDecl) + NewInclude = "module;\n" + NewInclude; return tooling::Replacement(FileName, InsertOffset, 0, NewInclude); } diff --git a/clang/unittests/Tooling/HeaderIncludesTest.cpp b/clang/unittests/Tooling/HeaderIncludesTest.cpp index befe4a3dd5a8a..df15ab33c9686 100644 --- a/clang/unittests/Tooling/HeaderIncludesTest.cpp +++ b/clang/unittests/Tooling/HeaderIncludesTest.cpp @@ -594,6 +594,66 @@ TEST_F(HeaderIncludesTest, CanDeleteAfterCode) { EXPECT_EQ(Expected, remove(Code, "\"b.h\"")); } +TEST_F(HeaderIncludesTest, InsertGlobalModuleFragmentDeclInterfaceUnit) { + // Ensure the header insertion comes with a global module fragment decl (i.e. + // a 'module;' line) when: + // - the input file is an module interface unit, and + // - no tokens excluding comments and whitespaces exist before the module + // declaration. 
+ std::string Code = R"cpp(// comments + +// more comments + +export module foo; + +int main() { + std::vector<int> ints {}; +})cpp"; + std::string Expected = R"cpp(// comments + +// more comments + +module; +#include <vector> +export module foo; + +int main() { + std::vector<int> ints {}; +})cpp"; + + EXPECT_EQ(Expected, insert(Code, "<vector>")); +} + +TEST_F(HeaderIncludesTest, InsertGlobalModuleFragmentDeclImplUnit) { + // Ensure the header insertion comes with a global module fragment decl (i.e. + // a 'module;' line) when: + // - the input file is an module implementation unit, and + // - no tokens excluding comments and whitespaces exist before the module + // declaration. + std::string Code = R"cpp(// comments + +// more comments + +module foo; + +int main() { + std::vector<int> ints {}; +})cpp"; + std::string Expected = R"cpp(// comments + +// more comments + +module; +#include <vector> +module foo; + +int main() { + std::vector<int> ints {}; +})cpp"; + + EXPECT_EQ(Expected, insert(Code, "<vector>")); +} + TEST_F(HeaderIncludesTest, InsertInGlobalModuleFragment) { // Ensure header insertions go only in the global module fragment std::string Code = R"cpp(// comments _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
