================
@@ -468,84 +468,129 @@ static Error runAOTCompile(StringRef InputFile,
StringRef OutputFile,
return createStringError(inconvertibleErrorCode(), "Unsupported arch");
}
+static constexpr char AttrSYCLModuleId[] = "sycl-module-id";
+
/// SYCL device code module split mode.
enum class IRSplitMode {
+ SPLIT_PER_TU, // one module per translation unit
SPLIT_PER_KERNEL, // one module per kernel
SPLIT_NONE // no splitting
};
-/// Parses the value of \p -module-split-mode.
+/// Parses the value of \p --module-split-mode.
static std::optional<IRSplitMode> convertStringToSplitMode(StringRef S) {
return StringSwitch<std::optional<IRSplitMode>>(S)
+ .Case("source", IRSplitMode::SPLIT_PER_TU)
.Case("kernel", IRSplitMode::SPLIT_PER_KERNEL)
.Case("none", IRSplitMode::SPLIT_NONE)
.Default(std::nullopt);
}
+static StringRef splitModeToString(IRSplitMode Mode) {
+ switch (Mode) {
+ case IRSplitMode::SPLIT_PER_TU:
+ return "source";
+ case IRSplitMode::SPLIT_PER_KERNEL:
+ return "kernel";
+ case IRSplitMode::SPLIT_NONE:
+ return "none";
+ }
+ llvm_unreachable("bad split mode");
+}
+
/// Result of splitting a device module: the bitcode file path and the
/// serialized symbol table for each device image.
struct SplitModule {
SmallString<256> ModuleFilePath;
SmallString<0> Symbols;
};
-static bool isEntryPoint(const Function &F) {
- return !F.isDeclaration() && F.hasKernelCallingConv();
+static bool isEntryPoint(const Function &F, bool EmitOnlyKernelsAsEntryPoints) {
+ if (F.isDeclaration())
+ return false;
+ if (F.hasKernelCallingConv())
+ return true;
+ if (EmitOnlyKernelsAsEntryPoints)
+ return false;
+ // sycl_external functions carry the "sycl-module-id" attribute.
+ return F.hasFnAttribute(AttrSYCLModuleId);
}
-/// Collect kernel names from \p M and serialize them into a symbol table.
-static SmallString<0> collectSymbols(const Module &M) {
- SmallVector<StringRef> KernelNames;
+/// Collect entry point names from \p M and serialize them into a symbol table.
+static SmallString<0> collectSymbols(const Module &M,
+ bool EmitOnlyKernelsAsEntryPoints) {
+ SmallVector<StringRef> Names;
for (const Function &F : M)
- if (isEntryPoint(F))
- KernelNames.push_back(F.getName());
+ if (isEntryPoint(F, EmitOnlyKernelsAsEntryPoints))
+ Names.push_back(F.getName());
SmallString<0> SymbolData;
- llvm::offloading::sycl::writeSymbolTable(KernelNames, SymbolData);
+ llvm::offloading::sycl::writeSymbolTable(Names, SymbolData);
return SymbolData;
}
+class EntryPointCategorizer {
+public:
+ EntryPointCategorizer(IRSplitMode Mode, bool EmitOnlyKernelsAsEntryPoints)
+ : Mode(Mode), OnlyKernelsAreEntryPoints(EmitOnlyKernelsAsEntryPoints) {}
+
+ std::optional<int> operator()(const Function &F) {
----------------
YuriPlyakhin wrote:
39f9343ac0df15b2789075f192c58925cc5d4879
https://github.com/llvm/llvm-project/pull/196435
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits