https://github.com/Lancern created 
https://github.com/llvm/llvm-project/pull/173724

This patch makes `HeaderIncludes` also insert a global module fragment 
declaration (`module;`) when inserting a header, provided that all of the 
following conditions are met:

- The source file is a module unit;
- No tokens other than comments and whitespace exist before the module 
  declaration.

This patch detects the conditions by checking whether the first declaration in 
the source file is a module declaration.

>From 36d602dbb6657254f3e07030976b30fa6d12a7f4 Mon Sep 17 00:00:00 2001
From: Sirui Mu <[email protected]>
Date: Sat, 27 Dec 2025 22:42:50 +0800
Subject: [PATCH] [clang][Tooling] Insert global module fragment during header
 insertion

This patch makes HeaderIncludes also insert a `module;` declaration when
inserting a header, provided that all of the following conditions are met:

- The source file is a module unit;
- No tokens other than comments and whitespace exist before the module
  declaration.

This patch detects the conditions by checking whether the first declaration in
the source file is a module declaration.
---
 .../clang/Tooling/Inclusions/HeaderIncludes.h |  3 +
 .../lib/Tooling/Inclusions/HeaderIncludes.cpp | 76 ++++++++++++++++---
 .../unittests/Tooling/HeaderIncludesTest.cpp  | 60 +++++++++++++++
 3 files changed, 129 insertions(+), 10 deletions(-)

diff --git a/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h b/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h
index d5439dd2c84eb..72407e2b12062 100644
--- a/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h
+++ b/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h
@@ -130,6 +130,9 @@ class HeaderIncludes {
   unsigned MaxInsertOffset;
   // True if we find the main-file header in the Code.
   bool MainIncludeFound;
+  // True if header insertion should also insert a C++20 global module fragment
+  // declaration (i.e. a 'module;' declaration).
+  bool ShouldInsertGlobalModuleFragmentDecl;
   IncludeCategoryManager Categories;
   // Record the offset of the end of the last include in each category.
   std::unordered_map<int, int> CategoryEndOffsets;
diff --git a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp
index e11319e99ba6a..1212af52e1490 100644
--- a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp
+++ b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp
@@ -32,6 +32,22 @@ LangOptions createLangOpts() {
   return LangOpts;
 }
 
+// Creates a new lexer on the given \p Code and calls \p Callback with the
+// created source manager and lexer. \p Callback must be a callable object that
+// can be invoked with (const SourceManager &, Lexer &). This function returns
+// whatever \p Callback returns.
+template <typename F>
+auto withLexer(StringRef FileName, StringRef Code, const IncludeStyle &Style,
+               F &&Callback)
+    -> std::invoke_result_t<F, const SourceManager &, Lexer &> {
+  SourceManagerForFile VirtualSM(FileName, Code);
+  SourceManager &SM = VirtualSM.get();
+  LangOptions LangOpts = createLangOpts();
+  Lexer Lex(SM.getMainFileID(), SM.getBufferOrFake(SM.getMainFileID()), SM,
+            LangOpts);
+  return std::invoke(std::forward<F>(Callback), std::as_const(SM), Lex);
+}
+
 // Returns the offset after skipping a sequence of tokens, matched by \p
 // GetOffsetAfterSequence, from the start of the code.
 // \p GetOffsetAfterSequence should be a function that matches a sequence of
@@ -40,15 +56,13 @@ unsigned getOffsetAfterTokenSequence(
     StringRef FileName, StringRef Code, const IncludeStyle &Style,
     llvm::function_ref<unsigned(const SourceManager &, Lexer &, Token &)>
         GetOffsetAfterSequence) {
-  SourceManagerForFile VirtualSM(FileName, Code);
-  SourceManager &SM = VirtualSM.get();
-  LangOptions LangOpts = createLangOpts();
-  Lexer Lex(SM.getMainFileID(), SM.getBufferOrFake(SM.getMainFileID()), SM,
-            LangOpts);
-  Token Tok;
-  // Get the first token.
-  Lex.LexFromRawLexer(Tok);
-  return GetOffsetAfterSequence(SM, Lex, Tok);
+  return withLexer(FileName, Code, Style,
+                   [&](const SourceManager &SM, Lexer &Lex) {
+                     Token Tok;
+                     // Get the first token.
+                     Lex.LexFromRawLexer(Tok);
+                     return GetOffsetAfterSequence(SM, Lex, Tok);
+                   });
 }
 
 // Check if a sequence of tokens is like "#<Name> <raw_identifier>". If it is,
@@ -190,6 +204,43 @@ unsigned getMaxHeaderInsertionOffset(StringRef FileName, StringRef Code,
       });
 }
 
+// Check whether the first declaration in the code is a C++20 module
+// declaration, and it is not preceded by any preprocessor directives.
+bool isFirstDeclModuleDecl(StringRef FileName, StringRef Code,
+                           const IncludeStyle &Style) {
+  return withLexer(
+      FileName, Code, Style, [](const SourceManager &SM, Lexer &Lex) {
+        // Let the lexer skip any comments and whitespaces for us.
+        Lex.SetKeepWhitespaceMode(false);
+        Lex.SetCommentRetentionState(false);
+
+        Token tok;
+        if (Lex.LexFromRawLexer(tok))
+          return false;
+
+        // A module declaration is made up of the following token sequence:
+        //     export? module <ident> ('.' <ident>)* <partition>? <attr>? ;
+        //
+        // For convenience, we don't actually lex the whole declaration -- to
+        // recognize a module declaration it's enough to check that an <ident>
+        // follows the "module" keyword.
+
+        // Lex the optional "export" keyword.
+        if (tok.is(tok::raw_identifier) && tok.getRawIdentifier() == "export") {
+          if (Lex.LexFromRawLexer(tok))
+            return false;
+        }
+
+        // Lex the "module" keyword.
+        if (!tok.is(tok::raw_identifier) ||
+            tok.getRawIdentifier() != "module" || Lex.LexFromRawLexer(tok))
+          return false;
+
+        // Make sure an identifier follows the "module" keyword.
+        return tok.is(tok::raw_identifier);
+      });
+}
+
 inline StringRef trimInclude(StringRef IncludeName) {
   return IncludeName.trim("\"<>");
 }
@@ -306,7 +357,10 @@ HeaderIncludes::HeaderIncludes(StringRef FileName, StringRef Code,
       MaxInsertOffset(MinInsertOffset +
                       getMaxHeaderInsertionOffset(
                           FileName, Code.drop_front(MinInsertOffset), Style)),
-      MainIncludeFound(false), Categories(Style, FileName) {
+      MainIncludeFound(false),
+      ShouldInsertGlobalModuleFragmentDecl(
+          isFirstDeclModuleDecl(FileName, Code, Style)),
+      Categories(Style, FileName) {
   // Add 0 for main header and INT_MAX for headers that are not in any
   // category.
   Priorities = {0, INT_MAX};
@@ -414,6 +468,8 @@ HeaderIncludes::insert(llvm::StringRef IncludeName, bool IsAngled,
   // newline should be added.
   if (InsertOffset == Code.size() && (!Code.empty() && Code.back() != '\n'))
     NewInclude = "\n" + NewInclude;
+  if (ShouldInsertGlobalModuleFragmentDecl)
+    NewInclude = "module;\n" + NewInclude;
   return tooling::Replacement(FileName, InsertOffset, 0, NewInclude);
 }
 
diff --git a/clang/unittests/Tooling/HeaderIncludesTest.cpp b/clang/unittests/Tooling/HeaderIncludesTest.cpp
index befe4a3dd5a8a..df15ab33c9686 100644
--- a/clang/unittests/Tooling/HeaderIncludesTest.cpp
+++ b/clang/unittests/Tooling/HeaderIncludesTest.cpp
@@ -594,6 +594,66 @@ TEST_F(HeaderIncludesTest, CanDeleteAfterCode) {
   EXPECT_EQ(Expected, remove(Code, "\"b.h\""));
 }
 
+TEST_F(HeaderIncludesTest, InsertGlobalModuleFragmentDeclInterfaceUnit) {
+  // Ensure the header insertion comes with a global module fragment decl (i.e.
+  // a 'module;' line) when:
+  //     - the input file is a module interface unit, and
+  //     - no tokens other than comments and whitespace exist before the module
+  //       declaration.
+  std::string Code = R"cpp(// comments
+
+// more comments
+
+export module foo;
+
+int main() {
+    std::vector<int> ints {};
+})cpp";
+  std::string Expected = R"cpp(// comments
+
+// more comments
+
+module;
+#include <vector>
+export module foo;
+
+int main() {
+    std::vector<int> ints {};
+})cpp";
+
+  EXPECT_EQ(Expected, insert(Code, "<vector>"));
+}
+
+TEST_F(HeaderIncludesTest, InsertGlobalModuleFragmentDeclImplUnit) {
+  // Ensure the header insertion comes with a global module fragment decl (i.e.
+  // a 'module;' line) when:
+  //     - the input file is a module implementation unit, and
+  //     - no tokens other than comments and whitespace exist before the module
+  //       declaration.
+  std::string Code = R"cpp(// comments
+
+// more comments
+
+module foo;
+
+int main() {
+    std::vector<int> ints {};
+})cpp";
+  std::string Expected = R"cpp(// comments
+
+// more comments
+
+module;
+#include <vector>
+module foo;
+
+int main() {
+    std::vector<int> ints {};
+})cpp";
+
+  EXPECT_EQ(Expected, insert(Code, "<vector>"));
+}
+
 TEST_F(HeaderIncludesTest, InsertInGlobalModuleFragment) {
   // Ensure header insertions go only in the global module fragment
   std::string Code = R"cpp(// comments

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to