================
@@ -468,84 +468,118 @@ static Error runAOTCompile(StringRef InputFile, StringRef OutputFile,
   return createStringError(inconvertibleErrorCode(), "Unsupported arch");
 }
 
+/// Name of the function attribute whose value is used to group entry points
+/// by source module under SPLIT_PER_TU splitting. sycl_external functions
+/// also carry this attribute, which is how non-kernel entry points are
+/// recognized.
+static constexpr char AttrSYCLModuleId[] = "sycl-module-id";
+
 /// SYCL device code module split mode.
 enum class IRSplitMode {
+  SPLIT_PER_TU,     // one module per translation unit ("sycl-module-id" value)
   SPLIT_PER_KERNEL, // one module per kernel
   SPLIT_NONE        // no splitting
 };
 
-/// Parses the value of \p -module-split-mode.
+/// Parses the value of \p --module-split-mode.
+/// Returns std::nullopt if \p S is not one of "source", "kernel" or "none".
 static std::optional<IRSplitMode> convertStringToSplitMode(StringRef S) {
   return StringSwitch<std::optional<IRSplitMode>>(S)
+      .Case("source", IRSplitMode::SPLIT_PER_TU)
       .Case("kernel", IRSplitMode::SPLIT_PER_KERNEL)
       .Case("none", IRSplitMode::SPLIT_NONE)
       .Default(std::nullopt);
 }
 
+/// Returns the \p --module-split-mode spelling of \p Mode; this is the inverse
+/// of convertStringToSplitMode.
+static StringRef splitModeToString(IRSplitMode Mode) {
+  switch (Mode) {
+  case IRSplitMode::SPLIT_PER_TU:
+    return "source";
+  case IRSplitMode::SPLIT_PER_KERNEL:
+    return "kernel";
+  case IRSplitMode::SPLIT_NONE:
+    return "none";
+  }
+  // Every enumerator is handled above; this only guards against an invalid
+  // enum value and silences missing-return warnings.
+  llvm_unreachable("bad split mode");
+}
+
 /// Result of splitting a device module: the bitcode file path and the
 /// serialized symbol table for each device image.
 struct SplitModule {
   SmallString<256> ModuleFilePath;
   SmallString<0> Symbols;
 };
 
-static bool isEntryPoint(const Function &F) {
-  return !F.isDeclaration() && F.hasKernelCallingConv();
+/// Returns true if \p F should be exported as an entry point of a device image.
+/// Declarations never are, and kernels always are. Other functions carrying
+/// the "sycl-module-id" attribute (sycl_external functions) are entry points
+/// only when \p EmitOnlyKernelsAsEntryPoints is false.
+static bool isEntryPoint(const Function &F, bool EmitOnlyKernelsAsEntryPoints) 
{
+  if (F.isDeclaration())
+    return false;
+  if (F.hasKernelCallingConv())
+    return true;
+  if (EmitOnlyKernelsAsEntryPoints)
+    return false;
+  // sycl_external functions carry the "sycl-module-id" attribute.
+  // This branch is not reachable while EmitOnlyKernelsAsEntryPoints is
+  // hardcoded to true (see TODO in runSYCLLink).
+  return F.hasFnAttribute(AttrSYCLModuleId);
 }
 
-/// Collect kernel names from \p M and serialize them into a symbol table.
-static SmallString<0> collectSymbols(const Module &M) {
-  SmallVector<StringRef> KernelNames;
+/// Collect entry point names from \p M and serialize them into a symbol table.
+/// \p EmitOnlyKernelsAsEntryPoints is forwarded to isEntryPoint to decide
+/// which functions qualify. Names are emitted in module function order.
+static SmallString<0> collectEntryPoints(const Module &M,
+                                         bool EmitOnlyKernelsAsEntryPoints) {
+  SmallVector<StringRef> Names;
   for (const Function &F : M)
-    if (isEntryPoint(F))
-      KernelNames.push_back(F.getName());
+    if (isEntryPoint(F, EmitOnlyKernelsAsEntryPoints))
+      Names.push_back(F.getName());
   SmallString<0> SymbolData;
-  llvm::offloading::sycl::writeSymbolTable(KernelNames, SymbolData);
+  llvm::offloading::sycl::writeSymbolTable(Names, SymbolData);
   return SymbolData;
 }
 
-/// Splits the fully linked device \p M into one bitcode file per device image
-/// according to \p Mode and returns the list of split images with their symbol
-/// tables.
-///
-/// For SPLIT_NONE, \p LinkedBitcodeFile is returned as-is.
-/// For SPLIT_PER_KERNEL, the module is split into parts such that each part
-/// contains exactly one kernel entry point and its transitive dependencies;
-/// each part is written to a fresh temporary bitcode file.
-static Expected<SmallVector<SplitModule, 0>>
-splitDeviceCode(std::unique_ptr<Module> M, StringRef LinkedBitcodeFile,
-                IRSplitMode Mode, const ArgList &Args) {
-  SmallVector<SplitModule, 0> SplitModules;
+/// Functor passed to splitModuleTransitiveFromEntryPoints. For each input \p 
F,
+/// returns a numeric group ID (if \p F is an entry point) determining which
+/// device image it lands in, or std::nullopt (for non-entry-points).
+/// SPLIT_PER_KERNEL \p Mode gives each kernel its own ID;
+/// SPLIT_PER_TU \p Mode groups kernels by their "sycl-module-id" attribute
+/// value.
+class EntryPointCategorizer {
+public:
+  EntryPointCategorizer(IRSplitMode Mode, bool EmitOnlyKernelsAsEntryPoints)
+      : Mode(Mode), OnlyKernelsAreEntryPoints(EmitOnlyKernelsAsEntryPoints) {}
 
-  if (Mode == IRSplitMode::SPLIT_NONE) {
-    SplitModules.push_back(
-        {SmallString<256>(LinkedBitcodeFile), collectSymbols(*M)});
-    return SplitModules;
-  }
+  std::optional<int> operator()(const Function &F) {
+    if (!isEntryPoint(F, OnlyKernelsAreEntryPoints))
+      return std::nullopt;
 
-  assert(Mode == IRSplitMode::SPLIT_PER_KERNEL);
+    std::string Key;
+    switch (Mode) {
+    case IRSplitMode::SPLIT_PER_KERNEL:
+      Key = F.getName().str();
+      break;
+    case IRSplitMode::SPLIT_PER_TU:
+      Key = F.getFnAttribute(AttrSYCLModuleId).getValueAsString().str();
+      break;
+    case IRSplitMode::SPLIT_NONE:
+      llvm_unreachable("categorizer should not be used for SPLIT_NONE");
----------------
bader wrote:

```suggestion
      llvm_unreachable("categorizer cannot be used for SPLIT_NONE");
```
The `llvm_unreachable` message _**should**_ be stronger than "should not". Please
use "cannot" instead of "should not", since the categorizer does not support
`SPLIT_NONE` at all.

https://github.com/llvm/llvm-project/pull/196435
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to