https://github.com/ivanradanov created https://github.com/llvm/llvm-project/pull/198102
AsyncInfoTy STATIC_NON_BLOCKING type. Strided array copies and mapping. No create mapping type. Refactoring initialization. Loading offload objects with OpenACC offloading kind. --- <sub>Stack created with <a href="https://github.com/github/gh-stack">GitHub Stacks CLI</a> • <a href="https://gh.io/stacks-feedback">Give Feedback 💬</a></sub> >From 0e9c3cc14186d97727b599dd48e2bf04ee37146e Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov <[email protected]> Date: Sat, 16 May 2026 07:04:01 -0700 Subject: [PATCH] [offload] Add new features to libompaccsupport for OpenACC AsyncInfoTy STATIC_NON_BLOCKING type. Strided array copies and mapping. No create mapping type. Refactoring initialization. Loading offload objects with OpenACC offloading kind. --- offload/include/OpenMP/Mapping.h | 13 +- offload/include/PluginManager.h | 19 +- offload/include/Shared/Debug.h | 47 ++++- offload/include/device.h | 23 +++ offload/include/omptarget.h | 61 ++++++- offload/include/rtl.h | 4 + offload/libompaccsupport/Mapping.cpp | 72 +++++++- offload/libompaccsupport/OffloadRTL.cpp | 32 ++-- offload/libompaccsupport/PluginManager.cpp | 155 ++++++++++++---- offload/libompaccsupport/device.cpp | 44 ++++- offload/libompaccsupport/exports | 4 + offload/libompaccsupport/interface.cpp | 27 +++ offload/libomptarget/exports | 5 +- offload/libomptarget/interface.cpp | 84 +++++---- offload/libomptarget/omptarget.cpp | 67 +++---- offload/plugins-nextgen/amdgpu/src/rtl.cpp | 14 ++ .../common/include/PluginInterface.h | 45 +++++ .../common/src/PluginInterface.cpp | 171 ++++++++++++++++++ offload/plugins-nextgen/cuda/src/rtl.cpp | 13 ++ offload/plugins-nextgen/host/src/rtl.cpp | 7 + 20 files changed, 733 insertions(+), 174 deletions(-) create mode 100644 offload/libompaccsupport/interface.cpp diff --git a/offload/include/OpenMP/Mapping.h b/offload/include/OpenMP/Mapping.h index e4024abf26690..ce1301474082a 100644 --- a/offload/include/OpenMP/Mapping.h +++ b/offload/include/OpenMP/Mapping.h @@ -667,9 +667,11 
@@ struct MappingInfoTy { /// - Data transfer issue fails. TargetPointerResultTy getTargetPointer( HDTTMapAccessorTy &HDTTMap, void *HstPtrBegin, void *HstPtrBase, - int64_t TgtPadding, int64_t Size, map_var_info_t HstPtrName, - bool HasFlagTo, bool HasFlagAlways, bool IsImplicit, bool UpdateRefCount, - bool HasCloseModifier, bool HasPresentModifier, bool HasHoldModifier, + int64_t TgtPadding, + std::variant<int64_t, const NonContigDescTy *> MemInfo, + map_var_info_t HstPtrName, bool HasFlagTo, bool HasFlagAlways, + bool IsImplicit, bool UpdateRefCount, bool HasCloseModifier, + bool HasPresentModifier, bool HasHoldModifier, bool IsNoCreate, AsyncInfoTy &AsyncInfo, HostDataToTargetTy *OwnedTPR = nullptr, bool ReleaseHDTTMap = true, StateInfoTy *StateInfo = nullptr); @@ -712,6 +714,11 @@ struct MappingInfoTy { int associatePtr(void *HstPtrBegin, void *TgtPtrBegin, int64_t Size); int disassociatePtr(void *HstPtrBegin); + void printNonContigCopyInfo(void *TgtPtrBegin, void *HstPtrBegin, + const NonContigDescTy &CopyInfo, bool H2D, + HostDataToTargetTy *Entry, + MappingInfoTy::HDTTMapAccessorTy *HDTTMapPtr); + /// Print information about the transfer from \p HstPtr to \p TgtPtr (or vice /// versa if \p H2D is false). If there is an existing mapping, or if \p Entry /// is set, the associated metadata will be printed as well. diff --git a/offload/include/PluginManager.h b/offload/include/PluginManager.h index 6c6fdebe76dff..4dd9fcd3de733 100644 --- a/offload/include/PluginManager.h +++ b/offload/include/PluginManager.h @@ -50,7 +50,7 @@ struct PluginManager { PluginManager() {} - void init(); + void initPlugins(); void deinit(); @@ -96,18 +96,19 @@ struct PluginManager { // Work around for plugins that call dlopen on shared libraries that call // tgt_register_lib during their initialisation. Stash the pointers in a // vector until the plugins are all initialised and then register them. 
- bool delayRegisterLib(__tgt_bin_desc *Desc) { + bool delayRegisterLib(std::function<void(__tgt_bin_desc *)> RegisterFunc, + __tgt_bin_desc *Desc) { if (RTLsLoaded) return false; - DelayedBinDesc.push_back(Desc); + DelayedBinDesc.push_back({RegisterFunc, Desc}); return true; } void registerDelayedLibraries() { // Only called by libomptarget constructor RTLsLoaded = true; - for (auto *Desc : DelayedBinDesc) - __tgt_register_lib(Desc); + for (auto &[RegisterFunc, Desc] : DelayedBinDesc) + RegisterFunc(Desc); DelayedBinDesc.clear(); } @@ -152,7 +153,9 @@ struct PluginManager { private: bool RTLsLoaded = false; - llvm::SmallVector<__tgt_bin_desc *> DelayedBinDesc; + llvm::SmallVector< + std::pair<std::function<void(__tgt_bin_desc *)>, __tgt_bin_desc *>> + DelayedBinDesc; // List of all plugins, in use or not. llvm::SmallVector<std::unique_ptr<GenericPluginTy>> Plugins; @@ -183,8 +186,8 @@ struct PluginManager { __tgt_bin_desc *upgradeLegacyEntries(__tgt_bin_desc *Desc); }; -/// Initialize the plugin manager and OpenMP runtime. -void initRuntime(); +/// Initialize the plugin manager. +void initRuntime(bool OffloadEnabled); /// Deinitialize the plugin and delete it. 
void deinitRuntime(); diff --git a/offload/include/Shared/Debug.h b/offload/include/Shared/Debug.h index 34f09051f41ba..3f60604e3f79e 100644 --- a/offload/include/Shared/Debug.h +++ b/offload/include/Shared/Debug.h @@ -172,6 +172,7 @@ class LLVM_ABI odbg_ostream final : public raw_ostream { uint32_t BaseLevel; bool ShouldPrefixNextString; bool ShouldEmitNewLineOnDestruction; + bool ShouldAbortOnDestruction; bool NeedEndNewLine = false; /// Buffer to reduce interference between different threads @@ -223,17 +224,22 @@ class LLVM_ABI odbg_ostream final : public raw_ostream { public: explicit odbg_ostream(std::string Prefix, raw_ostream &Os, uint32_t BaseLevel, bool ShouldPrefixNextString = true, - bool ShouldEmitNewLineOnDestruction = true) + bool ShouldEmitNewLineOnDestruction = true, + bool ShouldAbortOnDestruction = false) : Prefix(std::move(Prefix)), Os(Os), BaseLevel(BaseLevel), ShouldPrefixNextString(ShouldPrefixNextString), ShouldEmitNewLineOnDestruction(ShouldEmitNewLineOnDestruction), - BufferStrm(Buffer) { + ShouldAbortOnDestruction(ShouldAbortOnDestruction), BufferStrm(Buffer) { SetUnbuffered(); } ~odbg_ostream() final { if (ShouldEmitNewLineOnDestruction && NeedEndNewLine) BufferStrm << '\n'; Os << BufferStrm.str(); + if (ShouldAbortOnDestruction) { + Os.flush(); + abort(); + } } odbg_ostream(const odbg_ostream &) = delete; odbg_ostream &operator=(const odbg_ostream &) = delete; @@ -242,6 +248,7 @@ class LLVM_ABI odbg_ostream final : public raw_ostream { BaseLevel = other.BaseLevel; ShouldPrefixNextString = other.ShouldPrefixNextString; ShouldEmitNewLineOnDestruction = other.ShouldEmitNewLineOnDestruction; + ShouldAbortOnDestruction = other.ShouldAbortOnDestruction; NeedEndNewLine = other.NeedEndNewLine; Muted = other.Muted; BufferStrm << other.BufferStrm.str(); @@ -611,7 +618,8 @@ constexpr const char *ODT_Tool = OLDT_Tool; constexpr const char *ODT_Module = OLDT_Module; constexpr const char *ODT_Interop = "Interop"; -static inline odbg_ostream 
reportErrorStream() { +static inline odbg_ostream reportErrorStream(bool ShouldAbort, + std::string Prefix) { #ifdef OMPTARGET_DEBUG if (::llvm::offload::debug::isDebugEnabled()) { uint32_t RealLevel = ODL_Error; @@ -619,13 +627,26 @@ static inline odbg_ostream reportErrorStream() { (ODT_Error), RealLevel)) return odbg_ostream{ ::llvm::offload::debug::computePrefix(DEBUG_PREFIX, ODT_Error), - ::llvm::offload::debug::dbgs(), RealLevel}; + ::llvm::offload::debug::dbgs(), + RealLevel, + /*ShouldPrefixNextString=*/true, + /*ShouldEmitNewLineOnDestruction=*/true, + ShouldAbort}; else - return odbg_ostream{"", ::llvm::nulls(), 1}; + return odbg_ostream{"", + ::llvm::nulls(), + 1, + /*ShouldPrefixNextString=*/true, + /*ShouldEmitNewLineOnDestruction=*/true, + ShouldAbort}; } #endif - return odbg_ostream{GETNAME(TARGET_NAME) " error: ", - ::llvm::offload::debug::dbgs(), ODL_Error}; + return odbg_ostream{GETNAME(TARGET_NAME) + Prefix, + ::llvm::offload::debug::dbgs(), + ODL_Error, + /*ShouldPrefixNextString=*/true, + /*ShouldEmitNewLineOnDestruction=*/true, + ShouldAbort}; } #ifdef OMPTARGET_DEBUG @@ -693,8 +714,18 @@ static inline raw_ostream &operator<<(raw_ostream &Os, void *Ptr) { #endif // OMPTARGET_DEBUG +// New REPORT warning macro in the same style as ODBG +#define REPORT_WARN() \ + ::llvm::omp::target::debug::reportErrorStream(/*ShouldAbort=*/false, \ + " warning: ") +// New REPORT error macro in the same style as ODBG +#define REPORT() \ + ::llvm::omp::target::debug::reportErrorStream(/*ShouldAbort=*/false, \ + " error: ") // New REPORT macro in the same style as ODBG -#define REPORT() ::llvm::omp::target::debug::reportErrorStream() +#define REPORT_FATAL() \ + ::llvm::omp::target::debug::reportErrorStream(/*ShouldAbort=*/true, \ + " fatal error: ") } // namespace llvm::omp::target::debug diff --git a/offload/include/device.h b/offload/include/device.h index af103c316c3cf..c9afae32b3a1e 100644 --- a/offload/include/device.h +++ b/offload/include/device.h @@ -81,6 
+81,12 @@ struct DeviceTy { /// allocator should be used (host, shared, device). int32_t deleteData(void *TgtPtrBegin, int32_t Kind = TARGET_ALLOC_DEFAULT); + int32_t + submitNonContigData(void *TgtPtrBegin, void *HstPtrBegin, + const NonContigDescTy &CopyInfo, AsyncInfoTy &AsyncInfo, + HostDataToTargetTy *Entry = nullptr, + MappingInfoTy::HDTTMapAccessorTy *HDTTMapPtr = nullptr); + // Data transfer. When AsyncInfo is nullptr, the transfer will be // synchronous. // Copy data from host to device @@ -89,6 +95,12 @@ struct DeviceTy { HostDataToTargetTy *Entry = nullptr, MappingInfoTy::HDTTMapAccessorTy *HDTTMapPtr = nullptr); + int32_t + retrieveNonContigData(void *HstPtrBegin, void *TgtPtrBegin, + const NonContigDescTy &CopyInfo, AsyncInfoTy &AsyncInfo, + HostDataToTargetTy *Entry = nullptr, + MappingInfoTy::HDTTMapAccessorTy *HDTTMapPtr = nullptr); + // Copy data from device back to host int32_t retrieveData(void *HstPtrBegin, void *TgtPtrBegin, int64_t Size, AsyncInfoTy &AsyncInfo, @@ -120,16 +132,27 @@ struct DeviceTy { KernelExtraArgsTy *KernelExtraArgs, AsyncInfoTy &AsyncInfo); + // Enqueues a host function in the asynchronous queue. + int32_t enqueueHostCall(void (*Callback)(void *), void *UserData, + AsyncInfoTy &AsyncInfo); + /// Synchronize device/queue/event based on \p AsyncInfo and return /// OFFLOAD_SUCCESS/OFFLOAD_FAIL when succeeds/fails. int32_t synchronize(AsyncInfoTy &AsyncInfo); + /// Synchronize device/queue/event based on \p AsyncInfo without releasing the + /// queue and return QueueStatusTy::READY / QueueStatusTy::NOT_READY / + /// OFFLOAD_FAIL. + int32_t synchronizeStatic(AsyncInfoTy &AsyncInfo); + /// Query for device/queue/event based completion on \p AsyncInfo in a /// non-blocking manner and return OFFLOAD_SUCCESS/OFFLOAD_FAIL when /// succeeds/fails. Must be called multiple times until AsyncInfo is /// completed and AsyncInfo.isDone() returns true. 
int32_t queryAsync(AsyncInfoTy &AsyncInfo); + int32_t queryAsyncStatic(AsyncInfoTy &AsyncInfo); + /// Calls the corresponding print device info function in the plugin. bool printDeviceInfo(); diff --git a/offload/include/omptarget.h b/offload/include/omptarget.h index e5d9852ad48a6..57c9c4ddcffcf 100644 --- a/offload/include/omptarget.h +++ b/offload/include/omptarget.h @@ -20,11 +20,15 @@ #include "OpenMP/InternalTypes.h" +#include <atomic> #include <cstddef> #include <cstdint> #include <deque> #include <functional> +#include <map> +#include <mutex> #include <type_traits> +#include <variant> #include "llvm/ADT/SmallVector.h" @@ -117,7 +121,8 @@ struct DeviceTy; /// mistakes. class AsyncInfoTy { public: - enum class SyncTy { BLOCKING, NON_BLOCKING }; + enum class SyncTy { BLOCKING, NON_BLOCKING, STATIC_NON_BLOCKING }; + using PostProcFuncTy = std::function<int()>; private: /// Locations we used in (potentially) asynchronous calls which should live @@ -127,8 +132,8 @@ class AsyncInfoTy { /// Post-processing operations executed after a successful synchronization. /// \note the post-processing function should return OFFLOAD_SUCCESS or /// OFFLOAD_FAIL appropriately. - using PostProcFuncTy = std::function<int()>; llvm::SmallVector<PostProcFuncTy> PostProcessingFunctions; + std::mutex PostProcessingFunctionsMutex; __tgt_async_info AsyncInfo; DeviceTy &Device; @@ -139,30 +144,43 @@ class AsyncInfoTy { AsyncInfoTy(DeviceTy &Device, SyncTy SyncType = SyncTy::BLOCKING) : Device(Device), SyncType(SyncType) {} - ~AsyncInfoTy() { synchronize(); } + ~AsyncInfoTy() { finalize(); } /// Implicit conversion to the __tgt_async_info which is used in the /// plugin interface. operator __tgt_async_info *() { return &AsyncInfo; } - /// Synchronize all pending actions. + /// Finalizes this instance of AsyncInfoTy. /// - /// \note synchronization will be performance in a blocking or non-blocking - /// manner, depending on the SyncType. 
+ /// \note synchronization will be performed only if SyncType is blocking. /// - /// \note if the operations are completed, the registered post-processing - /// functions will be executed once and unregistered afterwards. + /// \note in all SyncType cases, if the operations are completed, the + /// registered post-processing functions will be executed once and + /// unregistered afterwards. /// /// \returns OFFLOAD_FAIL or OFFLOAD_SUCCESS appropriately. + int finalize(); + + /// Synchronize all pending actions. + /// + /// \returns OFFLOAD_FAIL or OFFLOAD_SUCCESS depending on whether an error was + /// encountered. int synchronize(); + /// Queries whether all pending actions are done. This function does not + /// return the queue to the RTL. + /// + /// \returns OFFLOAD_FAIL on error, 0 when the actions are done, and 1 when + /// they are pending. + int query(); + /// Return a void* reference with a lifetime that is at least as long as this /// AsyncInfoTy object. The location can be used as intermediate buffer. void *&getVoidPtrLocation(); /// Check if all asynchronous operations are completed. /// - /// \note only a lightweight check. If needed, use synchronize() to query the + /// \note only a lightweight check. If needed, use finalize() to query the /// status of AsyncInfo before checking. /// /// \returns true if there is no pending asynchronous operations, false @@ -178,6 +196,7 @@ class AsyncInfoTy { static_assert(std::is_convertible_v<FuncTy, PostProcFuncTy>, "Invalid post-processing function type. 
Please check " "function signature!"); + std::lock_guard<std::mutex> PPFGuard{PostProcessingFunctionsMutex}; PostProcessingFunctions.emplace_back(Function); } @@ -270,6 +289,30 @@ struct __tgt_target_non_contig { uint64_t Stride; }; +struct NonContigDescTy { + llvm::SmallVector<__tgt_target_non_contig, 6> Dims; + + const __tgt_target_non_contig &getDim(unsigned I) { return Dims[I]; } + unsigned getRank() const { return Dims.size(); } + + void mergeContiguousDims() { + int RemovedDim = 0; + for (int I = getRank() - 1; I > 0; --I) { + if (Dims[I].Count * Dims[I].Stride == Dims[I - 1].Stride) + RemovedDim++; + } + Dims.resize(getRank() - RemovedDim); + } + + uint64_t getLastDimCopySize() const { + return Dims.back().Count * Dims.back().Stride; + } + + uint64_t getAllocSize() const { + return (Dims[0].Count + Dims[0].Offset) * Dims[0].Stride; + } +}; + #ifdef __cplusplus extern "C" { #endif diff --git a/offload/include/rtl.h b/offload/include/rtl.h index 38f1dd24011e0..0320f6d15a458 100644 --- a/offload/include/rtl.h +++ b/offload/include/rtl.h @@ -53,4 +53,8 @@ struct TableMap { }; typedef std::map<void *, TableMap> HostPtrToTableMapTy; +namespace llvm::offload { +TableMap *getTableMap(void *HostPtr); +} + #endif diff --git a/offload/libompaccsupport/Mapping.cpp b/offload/libompaccsupport/Mapping.cpp index 1bb2e424bd083..6ca53b042b164 100644 --- a/offload/libompaccsupport/Mapping.cpp +++ b/offload/libompaccsupport/Mapping.cpp @@ -14,6 +14,10 @@ #include "Shared/Debug.h" #include "Shared/Requirements.h" #include "device.h" +#include "omptarget.h" +#include <cstdint> +#include <optional> +#include <variant> using namespace llvm::omp::target::debug; @@ -206,12 +210,22 @@ LookupResult MappingInfoTy::lookupMapping(HDTTMapAccessorTy &HDTTMap, TargetPointerResultTy MappingInfoTy::getTargetPointer( HDTTMapAccessorTy &HDTTMap, void *HstPtrBegin, void *HstPtrBase, - int64_t TgtPadding, int64_t Size, map_var_info_t HstPtrName, bool HasFlagTo, - bool HasFlagAlways, bool 
IsImplicit, bool UpdateRefCount, - bool HasCloseModifier, bool HasPresentModifier, bool HasHoldModifier, + int64_t TgtPadding, std::variant<int64_t, const NonContigDescTy *> MemInfo, + map_var_info_t HstPtrName, bool HasFlagTo, bool HasFlagAlways, + bool IsImplicit, bool UpdateRefCount, bool HasCloseModifier, + bool HasPresentModifier, bool HasHoldModifier, bool IsNoCreate, AsyncInfoTy &AsyncInfo, HostDataToTargetTy *OwnedTPR, bool ReleaseHDTTMap, StateInfoTy *StateInfo) { + int64_t Size; + const NonContigDescTy *CopyInfo = nullptr; + if (std::holds_alternative<int64_t>(MemInfo)) { + Size = std::get<int64_t>(MemInfo); + } else { + CopyInfo = std::get<const NonContigDescTy *>(MemInfo); + Size = CopyInfo->getAllocSize(); + } + LookupResult LR = lookupMapping(HDTTMap, HstPtrBegin, Size, OwnedTPR); LR.TPR.Flags.IsPresent = true; @@ -250,7 +264,8 @@ TargetPointerResultTy MappingInfoTy::getTargetPointer( LR.TPR.getEntry()->holdRefCountToStr().c_str(), HoldRefCountAction, (HstPtrName) ? getNameFromMapping(HstPtrName).c_str() : "unknown"); LR.TPR.TargetPointer = (void *)Ptr; - } else if ((LR.Flags.ExtendsBefore || LR.Flags.ExtendsAfter) && !IsImplicit) { + } else if ((LR.Flags.ExtendsBefore || LR.Flags.ExtendsAfter) && !IsImplicit && + !IsNoCreate) { // Explicit extension of mapped data - not allowed. MESSAGE("explicit extension not allowed: host address specified is " DPxMOD " (%" PRId64 @@ -290,7 +305,7 @@ TargetPointerResultTy MappingInfoTy::getTargetPointer( MESSAGE("device mapping required by 'present' map type modifier does not " "exist for host address " DPxMOD " (%" PRId64 " bytes)", DPxPTR(HstPtrBegin), Size); - } else if (Size) { + } else if (Size && !IsNoCreate) { // If it is not contained and Size > 0, we should create a new entry for it. 
LR.TPR.Flags.IsNewEntry = true; uintptr_t TgtAllocBegin = @@ -329,6 +344,12 @@ TargetPointerResultTy MappingInfoTy::getTargetPointer( if (ReleaseHDTTMap) HDTTMap.destroy(); + if (!LR.TPR.isPresent() && IsNoCreate) { + ODBG() << "Mapping for " << HstPtrBegin << " with size " << Size + << " does not exist and no_create is specified: returning."; + return std::move(LR.TPR); + } + // Lambda to check if this pointer was newly allocated on the current region. // This is needed to handle cases when the TO entry is encountered after an // alloc entry for the same pointer. In such cases, the ref-count is already @@ -384,8 +405,14 @@ TargetPointerResultTy MappingInfoTy::getTargetPointer( ODBG(ODT_Mapping) << "Moving " << Size << " bytes (hst:" << HstPtrBegin << ") -> (tgt:" << LR.TPR.TargetPointer << ")"; - int Ret = Device.submitData(LR.TPR.TargetPointer, HstPtrBegin, Size, - AsyncInfo, LR.TPR.getEntry()); + int Ret; + if (CopyInfo) { + Ret = Device.submitNonContigData(LR.TPR.TargetPointer, HstPtrBegin, + *CopyInfo, AsyncInfo, LR.TPR.getEntry()); + } else { + Ret = Device.submitData(LR.TPR.TargetPointer, HstPtrBegin, Size, + AsyncInfo, LR.TPR.getEntry()); + } if (Ret != OFFLOAD_SUCCESS) { REPORT() << "Copying data to device failed."; // We will also return nullptr if the data movement fails because that @@ -555,6 +582,27 @@ int MappingInfoTy::deallocTgtPtrAndEntry(HostDataToTargetTy *Entry, return Ret; } +static void printNonContigCopyInfoImpl(int DeviceId, bool H2D, + void *SrcPtrBegin, void *DstPtrBegin, + const NonContigDescTy &CopyInfo, + HostDataToTargetTy *HT) { + + INFO(OMP_INFOTYPE_DATA_TRANSFER, DeviceId, + "Copying non-contiguous data from %s to %s, %sPtr=" DPxMOD + ", %sPtr=" DPxMOD ", Name=%s\n", + H2D ? "host" : "device", H2D ? "device" : "host", H2D ? "Hst" : "Tgt", + DPxPTR(H2D ? SrcPtrBegin : DstPtrBegin), H2D ? "Tgt" : "Hst", + DPxPTR(H2D ? DstPtrBegin : SrcPtrBegin), + (HT && HT->HstPtrName) ? 
getNameFromMapping(HT->HstPtrName).c_str() + : "unknown"); + ODBG(ODT_Mapping) << "Non-contiguous data descriptor:\n"; + for (unsigned I = 0; I < CopyInfo.getRank(); I++) + ODBG(ODT_Mapping) << " Dim " << I << " : Offset " + << CopyInfo.Dims[I].Offset << " Count " + << CopyInfo.Dims[I].Count << " Stride " + << CopyInfo.Dims[I].Stride << "\n"; +} + static void printCopyInfoImpl(int DeviceId, bool H2D, void *SrcPtrBegin, void *DstPtrBegin, int64_t Size, HostDataToTargetTy *HT) { @@ -569,6 +617,16 @@ static void printCopyInfoImpl(int DeviceId, bool H2D, void *SrcPtrBegin, : "unknown"); } +void MappingInfoTy::printNonContigCopyInfo( + void *TgtPtrBegin, void *HstPtrBegin, const NonContigDescTy &CopyInfo, + bool H2D, HostDataToTargetTy *Entry, + MappingInfoTy::HDTTMapAccessorTy *HDTTMapPtr) { + auto HDTTMap = + HostDataToTargetMap.getExclusiveAccessor(!!Entry || !!HDTTMapPtr); + printNonContigCopyInfoImpl(Device.DeviceID, H2D, HstPtrBegin, TgtPtrBegin, + CopyInfo, Entry); +} + void MappingInfoTy::printCopyInfo( void *TgtPtrBegin, void *HstPtrBegin, int64_t Size, bool H2D, HostDataToTargetTy *Entry, MappingInfoTy::HDTTMapAccessorTy *HDTTMapPtr) { diff --git a/offload/libompaccsupport/OffloadRTL.cpp b/offload/libompaccsupport/OffloadRTL.cpp index 9b02376609cee..972138a8814ad 100644 --- a/offload/libompaccsupport/OffloadRTL.cpp +++ b/offload/libompaccsupport/OffloadRTL.cpp @@ -23,25 +23,15 @@ using namespace llvm::omp::target::debug; static std::mutex PluginMtx; static uint32_t RefCount = 0; +static bool PluginsInitialized = 0; std::atomic<bool> RTLAlive{false}; std::atomic<int> RTLOngoingSyncs{0}; -/// Check deleted and deprecated features, such as environment variables. -static void checkRuntimeEnvironment() { - const char *ShmemEnvarName = "LIBOMPTARGET_SHARED_MEMORY_SIZE"; - if (std::getenv(ShmemEnvarName)) - MESSAGE("Warning: %s is no longer valid. 
Please use OpenMP clause " - "'dyn_groupprivate' instead.\n", - ShmemEnvarName); -} - -void initRuntime() { +void initRuntime(bool OffloadEnabled) { std::scoped_lock<decltype(PluginMtx)> Lock(PluginMtx); Profiler::get(); TIMESCOPE(); - checkRuntimeEnvironment(); - if (PM == nullptr) PM = new PluginManager(); @@ -53,17 +43,27 @@ void initRuntime() { llvm::omp::target::ompt::connectLibrary(); #endif - PM->init(); - PM->registerDelayedLibraries(); + if (!OffloadEnabled) + ODBG(ODT_Init) << "Offload is disabled. Skipping plugin initialization"; // RTL initialization is complete RTLAlive = true; } + + // Initialize the plugins if at least one of the calls to this function is + // with OffloadEnabled == true + if (!PluginsInitialized && OffloadEnabled) { + ODBG(ODT_Init) << "Offload is enabled. Initializating plugins"; + PM->initPlugins(); + PM->registerDelayedLibraries(); + PluginsInitialized = true; + } } void deinitRuntime() { std::scoped_lock<decltype(PluginMtx)> Lock(PluginMtx); assert(PM && "Runtime not initialized"); + assert(RefCount != 0 && "Unmatched init and deinit"); if (RefCount == 1) { ODBG(ODT_Deinit) << "Deinit offload library!"; @@ -74,10 +74,12 @@ void deinitRuntime() { << RTLOngoingSyncs.load(); std::this_thread::sleep_for(std::chrono::milliseconds(100)); } + PM->deinit(); delete PM; PM = nullptr; - } + PluginsInitialized = false; + } RefCount--; } diff --git a/offload/libompaccsupport/PluginManager.cpp b/offload/libompaccsupport/PluginManager.cpp index 3241a8ecc764f..3f22f3fc931b7 100644 --- a/offload/libompaccsupport/PluginManager.cpp +++ b/offload/libompaccsupport/PluginManager.cpp @@ -15,6 +15,7 @@ #include "Shared/Debug.h" #include "Shared/Profile.h" #include "device.h" +#include "omptarget.h" #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" @@ -30,7 +31,7 @@ PluginManager *PM = nullptr; #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); #include "Shared/Targets.def" -int 
AsyncInfoTy::synchronize() { +int AsyncInfoTy::finalize() { int Result = OFFLOAD_SUCCESS; if (!isQueueEmpty()) { switch (SyncType) { @@ -44,6 +45,9 @@ int AsyncInfoTy::synchronize() { case SyncTy::NON_BLOCKING: Result = Device.queryAsync(*this); break; + case SyncTy::STATIC_NON_BLOCKING: + Result = Device.queryAsyncStatic(*this); + break; } } @@ -54,6 +58,67 @@ int AsyncInfoTy::synchronize() { return Result; } +static int32_t processPPFs(SmallVector<AsyncInfoTy::PostProcFuncTy> &PPFs) { + for (size_t I = 0; I < PPFs.size(); ++I) + if (int Res = PPFs[I](); Res != OFFLOAD_SUCCESS) + return Res; + return OFFLOAD_SUCCESS; +} + +int AsyncInfoTy::synchronize() { + assert(SyncType == SyncTy::STATIC_NON_BLOCKING); + + // We still have not created a queue, or this specific device does not + // generate queues. + if (isQueueEmpty()) + return OFFLOAD_SUCCESS; + + int Result = OFFLOAD_SUCCESS; + switch (SyncType) { + case SyncTy::BLOCKING: + case SyncTy::NON_BLOCKING: { + // BLOCKING and NON_BLOCKING types return the queue to the RTL after + // synchronization. + Result = Device.synchronize(*this); + assert(AsyncInfo.Queue == nullptr && + "The device plugin should have nulled the queue to indicate there " + "are no outstanding actions!"); + // Run any pending post-processing function registered on this async object. + if (Result == OFFLOAD_SUCCESS && isQueueEmpty()) + Result = runPostProcessing(); + return Result; + } + case SyncTy::STATIC_NON_BLOCKING: { + // STATIC_NON_BLOCKING retains its queue thus more careful handling of the + // post processing functions is required. + + // Collect the enqueued PPFs until this point. + SmallVector<PostProcFuncTy> LocalPPFs; + { + std::lock_guard<std::mutex> PPFGuard{PostProcessingFunctionsMutex}; + std::swap(LocalPPFs, PostProcessingFunctions); + } + Result = Device.synchronizeStatic(*this); + // Run any pending post-processing function collected _before_ we + // synchronize. 
This is important as between before the synchronization and + // after we could have enqueued more post processing operations, which we + // must not run yet. + if (Result == OFFLOAD_SUCCESS) + Result = processPPFs(LocalPPFs); + + return Result; + } + } + llvm_unreachable("Unexpected SyncType"); +} + +int AsyncInfoTy::query() { + // If we don't have a queue, there are no pending actions. + if (isQueueEmpty()) + return 0; + return Device.queryAsyncStatic(*this); +} + void *&AsyncInfoTy::getVoidPtrLocation() { BufferLocations.push_back(nullptr); return BufferLocations.back(); @@ -62,30 +127,26 @@ void *&AsyncInfoTy::getVoidPtrLocation() { bool AsyncInfoTy::isDone() const { return isQueueEmpty(); } int32_t AsyncInfoTy::runPostProcessing() { - size_t Size = PostProcessingFunctions.size(); - for (size_t I = 0; I < Size; ++I) { - const int Result = PostProcessingFunctions[I](); + // Post-processing procedures might add new procedures themselves, so + // repeatedly process them until we are done. + while (true) { + SmallVector<PostProcFuncTy> LocalPPFs; + { + std::lock_guard<std::mutex> PPFGuard{PostProcessingFunctionsMutex}; + std::swap(LocalPPFs, PostProcessingFunctions); + } + if (LocalPPFs.size() == 0) + return OFFLOAD_SUCCESS; + int32_t Result = processPPFs(LocalPPFs); if (Result != OFFLOAD_SUCCESS) return Result; } - - // Clear the vector up until the last known function, since post-processing - // procedures might add new procedures themselves. - const auto *PrevBegin = PostProcessingFunctions.begin(); - PostProcessingFunctions.erase(PrevBegin, PrevBegin + Size); - - return OFFLOAD_SUCCESS; } bool AsyncInfoTy::isQueueEmpty() const { return AsyncInfo.Queue == nullptr; } -void PluginManager::init() { +void PluginManager::initPlugins() { TIMESCOPE(); - if (OffloadPolicy::isOffloadDisabled()) { - ODBG(ODT_Init) << "Offload is disabled. 
Skipping plugin initialization"; - return; - } - ODBG(ODT_Init) << "Loading RTLs"; // Attempt to create an instance of each supported plugin. @@ -179,13 +240,6 @@ void PluginManager::initializeAllDevices() { initializeDevice(Plugin, DeviceId); } } - // After all plugins are initialized, register atExit cleanup handlers - std::atexit([]() { - // Interop cleanup should be done before the plugins are deinitialized as - // the backend libraries may be already unloaded. - if (PM) - PM->InteropTbl.clear(); - }); } // Returns a pointer to the binary descriptor, upgrading from a legacy format if @@ -365,10 +419,6 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) { ODBG(ODT_Init) << "Done registering entries!"; } -// Temporary forward declaration, old style CTor/DTor handling is going away. -int target(ident_t *Loc, DeviceTy &Device, void *HostPtr, - KernelArgsTy &KernelArgs, AsyncInfoTy &AsyncInfo); - void PluginManager::unregisterLib(__tgt_bin_desc *Desc) { ODBG(ODT_Deinit) << "Unloading target library!"; @@ -413,7 +463,8 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) { PM->TblMapMtx.lock(); for (llvm::offloading::EntryTy *Cur = Desc->HostEntriesBegin; Cur < Desc->HostEntriesEnd; ++Cur) { - if (Cur->Kind == object::OffloadKind::OFK_OpenMP) + if (Cur->Kind == object::OffloadKind::OFK_OpenMP || + Cur->Kind == object::OffloadKind::OFK_OpenACC) PM->HostPtrToTableMap.erase(Cur->Address); } @@ -483,7 +534,8 @@ static int loadImagesOntoDevice(DeviceTy &Device) { TransTable->TargetsEntries[DeviceId]; for (llvm::offloading::EntryTy &Entry : llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) { - if (Entry.Kind != object::OffloadKind::OFK_OpenMP) + if (Entry.Kind != object::OffloadKind::OFK_OpenMP && + Entry.Kind != object::OffloadKind::OFK_OpenACC) continue; __tgt_device_binary &Binary = *BinaryOrErr; @@ -537,7 +589,8 @@ static int loadImagesOntoDevice(DeviceTy &Device) { CurrDeviceEntry != EntryDeviceEnd; CurrDeviceEntry++, CurrHostEntry++) { if 
(CurrDeviceEntry->Size == 0 || - CurrDeviceEntry->Kind != object::OffloadKind::OFK_OpenMP) + (CurrDeviceEntry->Kind != object::OffloadKind::OFK_OpenMP && + CurrDeviceEntry->Kind != object::OffloadKind::OFK_OpenACC)) continue; assert(CurrDeviceEntry->Size == CurrHostEntry->Size && @@ -560,7 +613,7 @@ static int loadImagesOntoDevice(DeviceTy &Device) { void *DevPtr; Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *), AsyncInfo, /*Entry=*/nullptr, &HDTTMap); - if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS) + if (AsyncInfo.finalize() != OFFLOAD_SUCCESS) return OFFLOAD_FAIL; CurrDeviceEntryAddr = DevPtr; } @@ -620,3 +673,41 @@ Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) { DeviceNo); return *DevicePtr; } + +namespace llvm::offload { +/// Find the table information in the map or look it up in the translation +/// tables. +TableMap *getTableMap(void *HostPtr) { + std::lock_guard<std::mutex> TblMapLock(PM->TblMapMtx); + HostPtrToTableMapTy::iterator TableMapIt = + PM->HostPtrToTableMap.find(HostPtr); + + if (TableMapIt != PM->HostPtrToTableMap.end()) + return &TableMapIt->second; + + // We don't have a map. So search all the registered libraries. + TableMap *TM = nullptr; + std::lock_guard<std::mutex> TrlTblLock(PM->TrlTblMtx); + for (HostEntriesBeginToTransTableTy::iterator Itr = + PM->HostEntriesBeginToTransTable.begin(); + Itr != PM->HostEntriesBeginToTransTable.end(); ++Itr) { + // get the translation table (which contains all the good info). + TranslationTable *TransTable = &Itr->second; + // iterate over all the host table entries to see if we can locate the + // host_ptr. + llvm::offloading::EntryTy *Cur = TransTable->HostTable.EntriesBegin; + for (uint32_t I = 0; Cur < TransTable->HostTable.EntriesEnd; ++Cur, ++I) { + if (Cur->Address != HostPtr) + continue; + // we got a match, now fill the HostPtrToTableMap so that we + // may avoid this search next time. 
+ TM = &(PM->HostPtrToTableMap)[HostPtr]; + TM->Table = TransTable; + TM->Index = I; + return TM; + } + } + + return nullptr; +} +} // namespace llvm::offload diff --git a/offload/libompaccsupport/device.cpp b/offload/libompaccsupport/device.cpp index 546f679353544..010d05a51ab2c 100644 --- a/offload/libompaccsupport/device.cpp +++ b/offload/libompaccsupport/device.cpp @@ -19,7 +19,6 @@ #include "Shared/APITypes.h" #include "Shared/Debug.h" #include "omptarget.h" -#include "private.h" #include "rtl.h" #include "Shared/EnvironmentVar.h" @@ -40,6 +39,8 @@ using namespace llvm::omp::target::ompt; using namespace llvm::omp::target::plugin; using namespace llvm::omp::target::debug; +// TODO disable OMPT if we call from OpenACC + int HostDataToTargetTy::addEventIfNecessary(DeviceTy &Device, AsyncInfoTy &AsyncInfo) const { // First, check if the user disabled atomic map transfer/malloc/dealloc. @@ -123,7 +124,8 @@ setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image, Image->EntriesEnd); llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable; for (const auto &Entry : Entries) { - if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP || + if ((Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP && + Entry.Kind != llvm::object::OffloadKind::OFK_OpenACC) || Entry.Size == 0 || (!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT) && !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE))) @@ -280,6 +282,18 @@ int32_t DeviceTy::submitData(void *TgtPtrBegin, void *HstPtrBegin, int64_t Size, AsyncInfo); } +int32_t +DeviceTy::submitNonContigData(void *TgtPtrBegin, void *HstPtrBegin, + const NonContigDescTy &CopyInfo, + AsyncInfoTy &AsyncInfo, HostDataToTargetTy *Entry, + MappingInfoTy::HDTTMapAccessorTy *HDTTMapPtr) { + if (getInfoLevel() & OMP_INFOTYPE_DATA_TRANSFER) + MappingInfo.printNonContigCopyInfo(TgtPtrBegin, HstPtrBegin, CopyInfo, + /*H2D=*/true, Entry, HDTTMapPtr); + return RTL->data_non_contig_submit_async(RTLDeviceID, TgtPtrBegin, + HstPtrBegin, CopyInfo, 
AsyncInfo); +} + // Retrieve data from device int32_t DeviceTy::retrieveData(void *HstPtrBegin, void *TgtPtrBegin, int64_t Size, AsyncInfoTy &AsyncInfo, @@ -300,6 +314,17 @@ int32_t DeviceTy::retrieveData(void *HstPtrBegin, void *TgtPtrBegin, AsyncInfo); } +int32_t DeviceTy::retrieveNonContigData( + void *HstPtrBegin, void *TgtPtrBegin, const NonContigDescTy &CopyInfo, + AsyncInfoTy &AsyncInfo, HostDataToTargetTy *Entry, + MappingInfoTy::HDTTMapAccessorTy *HDTTMapPtr) { + if (getInfoLevel() & OMP_INFOTYPE_DATA_TRANSFER) + MappingInfo.printNonContigCopyInfo(TgtPtrBegin, HstPtrBegin, CopyInfo, + /*H2D=*/false, Entry, HDTTMapPtr); + return RTL->data_non_contig_retrieve_async(RTLDeviceID, HstPtrBegin, + TgtPtrBegin, CopyInfo, AsyncInfo); +} + // Copy data from current device to destination device directly int32_t DeviceTy::dataExchange(void *SrcPtr, DeviceTy &DstDev, void *DstPtr, int64_t Size, AsyncInfoTy &AsyncInfo) { @@ -347,7 +372,6 @@ int32_t DeviceTy::notifyDataUnmapped(void *HstPtr) { return OFFLOAD_SUCCESS; } -// Run region on device int32_t DeviceTy::launchKernel(void *TgtEntryPtr, void **TgtVarsPtr, ptrdiff_t *TgtOffsets, KernelArgsTy &KernelArgs, KernelExtraArgsTy *KernelExtraArgs, @@ -356,7 +380,11 @@ int32_t DeviceTy::launchKernel(void *TgtEntryPtr, void **TgtVarsPtr, &KernelArgs, KernelExtraArgs, AsyncInfo); } -// Run region on device +int32_t DeviceTy::enqueueHostCall(void (*Callback)(void *), void *UserData, + AsyncInfoTy &AsyncInfo) { + return RTL->enqueue_host_call(RTLDeviceID, Callback, UserData, AsyncInfo); +} + bool DeviceTy::printDeviceInfo() { RTL->print_device_info(RTLDeviceID); return true; @@ -376,10 +404,18 @@ int32_t DeviceTy::synchronize(AsyncInfoTy &AsyncInfo) { return RTL->synchronize(RTLDeviceID, AsyncInfo); } +int32_t DeviceTy::synchronizeStatic(AsyncInfoTy &AsyncInfo) { + return RTL->synchronize_static(RTLDeviceID, AsyncInfo); +} + int32_t DeviceTy::queryAsync(AsyncInfoTy &AsyncInfo) { return RTL->query_async(RTLDeviceID, AsyncInfo); } 
+int32_t DeviceTy::queryAsyncStatic(AsyncInfoTy &AsyncInfo) { + return RTL->query_async_static(RTLDeviceID, AsyncInfo); +} + int32_t DeviceTy::createEvent(void **Event) { return RTL->create_event(RTLDeviceID, Event); } diff --git a/offload/libompaccsupport/exports b/offload/libompaccsupport/exports index b67f7cbd8890a..68ae4c0443a98 100644 --- a/offload/libompaccsupport/exports +++ b/offload/libompaccsupport/exports @@ -1,5 +1,9 @@ VERS1.0 { global: + __tgt_rtl_init; + __tgt_rtl_deinit; + __tgt_init_all_rtls; + __tgt_register_rpc_callback; *; local: diff --git a/offload/libompaccsupport/interface.cpp b/offload/libompaccsupport/interface.cpp new file mode 100644 index 0000000000000..f178476d7c5ab --- /dev/null +++ b/offload/libompaccsupport/interface.cpp @@ -0,0 +1,27 @@ +//===-------- interface.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PluginManager.h" +#include "omptarget.h" + +EXTERN void __tgt_rtl_init() { initRuntime(/*OffloadEnabled=*/true); } +EXTERN void __tgt_rtl_deinit() { deinitRuntime(); } + +//////////////////////////////////////////////////////////////////////////////// +/// Initialize all available devices without registering any image +EXTERN void __tgt_init_all_rtls() { + assert(PM && "Runtime not initialized"); + PM->initializeAllDevices(); +} + +EXTERN void __tgt_register_rpc_callback(unsigned (*Callback)(void *, + unsigned)) { + for (auto &Plugin : PM->plugins()) + if (Plugin.is_initialized()) + Plugin.getRPCServer().registerCallback(Callback); +} diff --git a/offload/libomptarget/exports b/offload/libomptarget/exports index 1831c43cc5f29..7073d25fabd55 100644 --- a/offload/libomptarget/exports +++ b/offload/libomptarget/exports @@ -1,11 +1,8 @@ VERS1.0 { global: - __tgt_rtl_init; - __tgt_rtl_deinit; __tgt_register_requires; __tgt_register_lib; __tgt_unregister_lib; - __tgt_init_all_rtls; __tgt_target_data_begin; __tgt_target_data_end; __tgt_target_data_update; @@ -79,10 +76,10 @@ VERS1.0 { __tgt_interop_use60; __tgt_interop_release; __tgt_target_sync; - __tgt_register_rpc_callback; __llvmPushCallConfiguration; __llvmPopCallConfiguration; llvmLaunchKernel; + local: *; }; diff --git a/offload/libomptarget/interface.cpp b/offload/libomptarget/interface.cpp index 64ff078e3ec46..b1daeecac757f 100644 --- a/offload/libomptarget/interface.cpp +++ b/offload/libomptarget/interface.cpp @@ -38,6 +38,47 @@ using namespace llvm::omp::target::ompt; #endif using namespace llvm::omp::target::debug; +static std::mutex InitMutex; +static uint32_t InitRefCount = 0; + +/// Check deleted and deprecated features, such as environment variables. 
+static void checkRuntimeEnvironment() { + const char *ShmemEnvarName = "LIBOMPTARGET_SHARED_MEMORY_SIZE"; + if (std::getenv(ShmemEnvarName)) + MESSAGE("Warning: %s is no longer valid. Please use OpenMP clause " + "'dyn_groupprivate' instead.\n", + ShmemEnvarName); +} + +//////////////////////////////////////////////////////////////////////////////// +/// adds a target shared library to the target execution image +EXTERN void __tgt_register_lib(__tgt_bin_desc *Desc) { + std::scoped_lock<decltype(InitMutex)> Lock(InitMutex); + checkRuntimeEnvironment(); + initRuntime(!OffloadPolicy::isOffloadDisabled()); + if (PM->delayRegisterLib(__tgt_register_lib, Desc)) + return; + + PM->registerLib(Desc); + InitRefCount++; +} + +//////////////////////////////////////////////////////////////////////////////// +/// unloads a target shared library +EXTERN void __tgt_unregister_lib(__tgt_bin_desc *Desc) { + std::scoped_lock<decltype(InitMutex)> Lock(InitMutex); + PM->unregisterLib(Desc); + + if (InitRefCount == 1) { + // Interop cleanup should be done before the plugins are deinitialized as + // the backend libraries may be already unloaded. + if (PM) + PM->InteropTbl.clear(); + } + InitRefCount--; + deinitRuntime(); +} + // If offload is enabled, ensure that device DeviceID has been initialized. 
// // The return bool indicates if the offload is to the host device @@ -83,34 +124,6 @@ EXTERN void __tgt_register_requires(int64_t Flags) { __PRETTY_FUNCTION__); } -EXTERN void __tgt_rtl_init() { initRuntime(); } -EXTERN void __tgt_rtl_deinit() { deinitRuntime(); } - -//////////////////////////////////////////////////////////////////////////////// -/// adds a target shared library to the target execution image -EXTERN void __tgt_register_lib(__tgt_bin_desc *Desc) { - initRuntime(); - if (PM->delayRegisterLib(Desc)) - return; - - PM->registerLib(Desc); -} - -//////////////////////////////////////////////////////////////////////////////// -/// Initialize all available devices without registering any image -EXTERN void __tgt_init_all_rtls() { - assert(PM && "Runtime not initialized"); - PM->initializeAllDevices(); -} - -//////////////////////////////////////////////////////////////////////////////// -/// unloads a target shared library -EXTERN void __tgt_unregister_lib(__tgt_bin_desc *Desc) { - PM->unregisterLib(Desc); - - deinitRuntime(); -} - template <typename TargetAsyncInfoTy> static inline void targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase, @@ -187,7 +200,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase, Rc = processAttachEntries(*DeviceOrErr, *StateInfo, AsyncInfo); if (Rc == OFFLOAD_SUCCESS) - Rc = AsyncInfo.synchronize(); + Rc = AsyncInfo.finalize(); } handleTargetOutcome(Rc == OFFLOAD_SUCCESS, Loc); @@ -440,7 +453,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams, { // required to show synchronization TIMESCOPE_WITH_DETAILS_AND_IDENT("Runtime: synchronize", "", Loc); if (Rc == OFFLOAD_SUCCESS) - Rc = AsyncInfo.synchronize(); + Rc = AsyncInfo.finalize(); handleTargetOutcome(Rc == OFFLOAD_SUCCESS, Loc); assert(Rc == OFFLOAD_SUCCESS && "__tgt_target_kernel unexpected failure!"); @@ -552,7 +565,7 @@ EXTERN int __tgt_target_kernel_replay( LoopTripCount, AsyncInfo, 
ReplayOutcome); if (Rc == OFFLOAD_SUCCESS) - Rc = AsyncInfo.synchronize(); + Rc = AsyncInfo.finalize(); if (Rc != OFFLOAD_SUCCESS) { ODBG(ODT_Interface) << "Kernel replay failed in device " << DeviceId; @@ -626,7 +639,7 @@ EXTERN void __tgt_target_nowait_query(void **AsyncHandle) { if (QueryCounter.isAboveThreshold()) AsyncInfo->SyncType = AsyncInfoTy::SyncTy::BLOCKING; - if (AsyncInfo->synchronize()) + if (AsyncInfo->finalize()) FATAL_MESSAGE0(1, "Error while querying the async queue for completion.\n"); // If there are device operations still pending, return immediately without // deallocating the handle and increase the current thread query count. @@ -643,10 +656,3 @@ EXTERN void __tgt_target_nowait_query(void **AsyncHandle) { delete AsyncInfo; *AsyncHandle = nullptr; } - -EXTERN void __tgt_register_rpc_callback(unsigned (*Callback)(void *, - unsigned)) { - for (auto &Plugin : PM->plugins()) - if (Plugin.is_initialized()) - Plugin.getRPCServer().registerCallback(Callback); -} diff --git a/offload/libomptarget/omptarget.cpp b/offload/libomptarget/omptarget.cpp index 6853b7155e3ec..c356c129f0e6b 100644 --- a/offload/libomptarget/omptarget.cpp +++ b/offload/libomptarget/omptarget.cpp @@ -16,6 +16,7 @@ #include "OpenMP/OMPT/Callback.h" #include "OpenMP/OMPT/Interface.h" #include "PluginManager.h" +#include "Shared/APITypes.h" #include "Shared/Debug.h" #include "Shared/EnvironmentVar.h" #include "Shared/Utils.h" @@ -472,8 +473,13 @@ int targetDataBegin(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, if ((ArgTypes[I] & OMP_TGT_MAPTYPE_LITERAL) || (ArgTypes[I] & OMP_TGT_MAPTYPE_PRIVATE)) continue; + + assert(!(ArgTypes[I] & OMP_TGT_MAPTYPE_NON_CONTIG)); + + int64_t DataSize = ArgSizes[I]; + TIMESCOPE_WITH_DETAILS_AND_IDENT( - "HostToDev", "Size=" + std::to_string(ArgSizes[I]) + "B", Loc); + "HostToDev", "Size=" + std::to_string(DataSize) + "B", Loc); if (ArgMappers && ArgMappers[I]) { // Instead of executing the regular path of targetDataBegin, call the // 
targetDataMapper variant which will call targetDataBegin again @@ -482,7 +488,7 @@ int targetDataBegin(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, << "th argument"; map_var_info_t ArgName = (!ArgNames) ? nullptr : ArgNames[I]; - int Rc = targetDataMapper(Loc, Device, ArgsBase[I], Args[I], ArgSizes[I], + int Rc = targetDataMapper(Loc, Device, ArgsBase[I], Args[I], DataSize, ArgTypes[I], ArgName, ArgMappers[I], AsyncInfo, targetDataBegin, StateInfo); @@ -498,7 +504,6 @@ int targetDataBegin(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, void *HstPtrBegin = Args[I]; void *HstPtrBase = ArgsBase[I]; - int64_t DataSize = ArgSizes[I]; map_var_info_t HstPtrName = (!ArgNames) ? nullptr : ArgNames[I]; // ATTACH map-types are supposed to be handled after all mapping for the @@ -571,10 +576,12 @@ int targetDataBegin(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, // PTR_AND_OBJ entry is handled below, and so the allocation might fail // when HasPresentModifier. PointerTpr = Device.getMappingInfo().getTargetPointer( - HDTTMap, HstPtrBase, HstPtrBase, /*TgtPadding=*/0, sizeof(void *), + HDTTMap, HstPtrBase, HstPtrBase, /*TgtPadding=*/0, + static_cast<int64_t>(sizeof(void *)), /*HstPtrName=*/nullptr, /*HasFlagTo=*/false, /*HasFlagAlways=*/false, IsImplicit, UpdateRef, - HasCloseModifier, HasPresentModifier, HasHoldModifier, AsyncInfo, + HasCloseModifier, HasPresentModifier, HasHoldModifier, + /*IsNoCreate=*/false, AsyncInfo, /*OwnedTPR=*/nullptr, /*ReleaseHDTTMap=*/false); PointerTgtPtrBegin = PointerTpr.TargetPointer; IsHostPtr = PointerTpr.Flags.IsHostPointer; @@ -608,9 +615,10 @@ int targetDataBegin(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, const bool HasFlagAlways = ArgTypes[I] & OMP_TGT_MAPTYPE_ALWAYS; // Note that HDTTMap will be released in getTargetPointer. 
auto TPR = Device.getMappingInfo().getTargetPointer( - HDTTMap, HstPtrBegin, HstPtrBase, TgtPadding, DataSize, HstPtrName, + HDTTMap, HstPtrBegin, HstPtrBase, TgtPadding, ArgSizes[I], HstPtrName, HasFlagTo, HasFlagAlways, IsImplicit, UpdateRef, HasCloseModifier, - HasPresentModifier, HasHoldModifier, AsyncInfo, PointerTpr.getEntry(), + HasPresentModifier, HasHoldModifier, /*IsNoCreate=*/false, AsyncInfo, + PointerTpr.getEntry(), /*ReleaseHDTTMap=*/true, StateInfo); void *TgtPtrBegin = TPR.TargetPointer; IsHostPtr = TPR.Flags.IsHostPointer; @@ -1068,8 +1076,10 @@ int targetDataEnd(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, continue; } - void *HstPtrBegin = Args[I]; + assert(!(ArgTypes[I] & OMP_TGT_MAPTYPE_NON_CONTIG)); int64_t DataSize = ArgSizes[I]; + + void *HstPtrBegin = Args[I]; bool IsImplicit = ArgTypes[I] & OMP_TGT_MAPTYPE_IMPLICIT; bool UpdateRef = !(ArgTypes[I] & OMP_TGT_MAPTYPE_MEMBER_OF) || (ArgTypes[I] & OMP_TGT_MAPTYPE_PTR_AND_OBJ); @@ -1182,7 +1192,8 @@ int targetDataEnd(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, } } - int Ret = Device.retrieveData(HstPtr, TgtPtr, Size, AsyncInfo, Entry); + int Ret = + Device.retrieveData(HstPtr, TgtPtr, Size, AsyncInfo, TPR.getEntry()); if (Ret != OFFLOAD_SUCCESS) { REPORT() << "Copying data from device failed."; return OFFLOAD_FAIL; @@ -1561,43 +1572,9 @@ static bool isLambdaMapping(int64_t Mapping) { return (Mapping & LambdaMapping) == LambdaMapping; } -namespace { -/// Find the table information in the map or look it up in the translation -/// tables. -TableMap *getTableMap(void *HostPtr) { - std::lock_guard<std::mutex> TblMapLock(PM->TblMapMtx); - HostPtrToTableMapTy::iterator TableMapIt = - PM->HostPtrToTableMap.find(HostPtr); - - if (TableMapIt != PM->HostPtrToTableMap.end()) - return &TableMapIt->second; - - // We don't have a map. So search all the registered libraries. 
- TableMap *TM = nullptr; - std::lock_guard<std::mutex> TrlTblLock(PM->TrlTblMtx); - for (HostEntriesBeginToTransTableTy::iterator Itr = - PM->HostEntriesBeginToTransTable.begin(); - Itr != PM->HostEntriesBeginToTransTable.end(); ++Itr) { - // get the translation table (which contains all the good info). - TranslationTable *TransTable = &Itr->second; - // iterate over all the host table entries to see if we can locate the - // host_ptr. - llvm::offloading::EntryTy *Cur = TransTable->HostTable.EntriesBegin; - for (uint32_t I = 0; Cur < TransTable->HostTable.EntriesEnd; ++Cur, ++I) { - if (Cur->Address != HostPtr) - continue; - // we got a match, now fill the HostPtrToTableMap so that we - // may avoid this search next time. - TM = &(PM->HostPtrToTableMap)[HostPtr]; - TM->Table = TransTable; - TM->Index = I; - return TM; - } - } - - return nullptr; -} +using llvm::offload::getTableMap; +namespace { /// A class manages private arguments in a target region. class PrivateArgumentManagerTy { /// A data structure for the information of first-private arguments. We can diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index f4e81025a285d..ba3127132b88d 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -2684,6 +2684,20 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { return Plugin::success(); } + /// Query for the completion of the pending operations on the async info. + Expected<QueueStatusTy> + queryAsyncStaticImpl(__tgt_async_info &AsyncInfo) override { + AMDGPUStreamTy *Stream = + reinterpret_cast<AMDGPUStreamTy *>(AsyncInfo.Queue); + assert(Stream && "Invalid stream"); + + auto CompletedOrErr = Stream->query(); + if (!CompletedOrErr) + return CompletedOrErr.takeError(); + + return *CompletedOrErr ? 
QueueStatusTy::READY : QueueStatusTy::NOT_READY; + } + /// Pin the host buffer and return the device pointer that should be used for /// device transfers. Expected<void *> dataLockImpl(void *HstPtr, int64_t Size) override { diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index 3a354af4775d8..ae147bd09530a 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -825,6 +825,8 @@ class PinnedAllocationMapTy { } }; +enum class QueueStatusTy { READY = 0, NOT_READY = 1 }; + /// Class implementing common functionalities of offload devices. Each plugin /// should define the specific device class, derive from this generic one, and /// implement the necessary virtual function members. @@ -925,6 +927,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy { virtual Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue, bool *IsQueueWorkCompleted) = 0; + Expected<QueueStatusTy> queryAsyncStatic(__tgt_async_info *AsyncInfo); + virtual Expected<QueueStatusTy> + queryAsyncStaticImpl(__tgt_async_info &AsyncInfo) = 0; + /// Check whether the architecture supports VA management virtual bool supportVAManagement() const { return false; } @@ -1003,12 +1009,28 @@ struct GenericDeviceTy : public DeviceAllocatorTy { virtual Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size, AsyncInfoWrapperTy &AsyncInfoWrapper) = 0; + /// Submit non-contiguous data to the device (host to device transfer). + Error dataNonContigSubmit(void *TgtPtr, const void *HstPtr, + const NonContigDescTy &CopyInfo, + __tgt_async_info *AsyncInfo); + virtual Error dataNonContigSubmitImpl(void *TgtPtr, const void *HstPtr, + const NonContigDescTy &CopyInfo, + AsyncInfoWrapperTy &AsyncInfoWrapper); + /// Retrieve data from the device (device to host transfer). 
Error dataRetrieve(void *HstPtr, const void *TgtPtr, int64_t Size, __tgt_async_info *AsyncInfo); virtual Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size, AsyncInfoWrapperTy &AsyncInfoWrapper) = 0; + /// Retrieve non-contiguous data from the device (device to host transfer). + Error dataNonContigRetrieve(void *HstPtr, const void *TgtPtr, + const NonContigDescTy &CopyInfo, + __tgt_async_info *AsyncInfo); + virtual Error dataNonContigRetrieveImpl(void *HstPtr, const void *TgtPtr, + const NonContigDescTy &CopyInfo, + AsyncInfoWrapperTy &AsyncInfoWrapper); + /// Instert a data fence between previous data operations and the following /// operations if necessary for the device virtual Error dataFence(__tgt_async_info *AsyncInfo) = 0; @@ -1605,10 +1627,22 @@ struct GenericPluginTy { int32_t data_submit_async(int32_t DeviceId, void *TgtPtr, void *HstPtr, int64_t Size, __tgt_async_info *AsyncInfoPtr); + /// Copy non-contiguous data to the given device asynchronously. + int32_t data_non_contig_submit_async(int32_t DeviceId, void *TgtPtr, + void *HstPtr, + const NonContigDescTy &CopyInfo, + __tgt_async_info *AsyncInfoPtr); + /// Copy data from the given device. int32_t data_retrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr, int64_t Size); + /// Copy non-contiguous data from the given device asynchronously. + int32_t data_non_contig_retrieve_async(int32_t DeviceId, void *HstPtr, + void *TgtPtr, + const NonContigDescTy &CopyInfo, + __tgt_async_info *AsyncInfoPtr); + /// Copy data from the given device asynchronously. int32_t data_retrieve_async(int32_t DeviceId, void *HstPtr, void *TgtPtr, int64_t Size, __tgt_async_info *AsyncInfoPtr); @@ -1632,15 +1666,26 @@ struct GenericPluginTy { KernelExtraArgsTy *KernelExtraArgs, __tgt_async_info *AsyncInfoPtr); + /// Enqueue a host call into the asynchronous queue. 
+ int32_t enqueue_host_call(int32_t DeviceId, void (*Callback)(void *), + void *UserData, __tgt_async_info *AsyncInfo); + /// Synchronize an asyncrhonous queue with the plugin runtime. int32_t synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr); + /// Synchronize an asynchronous queue with the plugin runtime without + /// releasing it. + int32_t synchronize_static(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr); + /// Query the current state of an asynchronous queue. int32_t query_async(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr); /// Obtain information about the given device. InfoTreeNode obtain_device_info(int32_t DeviceId); + /// Query the current state of an asynchronous queue without releasing it. + int32_t query_async_static(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr); + /// Prints information about the given devices supported by the plugin. void print_device_info(int32_t DeviceId); diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 4a8bf8d257344..0b5b357b68433 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -953,6 +953,15 @@ Error GenericDeviceTy::queryAsync(__tgt_async_info *AsyncInfo, return queryAsyncImpl(*AsyncInfo, ReleaseQueue, IsQueueWorkCompleted); } +Expected<QueueStatusTy> +GenericDeviceTy::queryAsyncStatic(__tgt_async_info *AsyncInfo) { + if (!AsyncInfo || !AsyncInfo->Queue) + return Plugin::error(ErrorCode::INVALID_ARGUMENT, + "invalid async info queue"); + + return queryAsyncStaticImpl(*AsyncInfo); +} + Error GenericDeviceTy::memoryVAMap(void **Addr, void *VAddr, size_t *RSize) { return Plugin::error(ErrorCode::UNSUPPORTED, "device does not support VA Management"); @@ -1102,6 +1111,71 @@ Error GenericDeviceTy::dataSubmit(void *TgtPtr, const void *HstPtr, return Err; } +Error GenericDeviceTy::dataNonContigSubmit(void *TgtPtr, const void *HstPtr, + const NonContigDescTy &CopyInfo, + 
__tgt_async_info *AsyncInfo) { + AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo); + + auto Err = + dataNonContigSubmitImpl(TgtPtr, HstPtr, CopyInfo, AsyncInfoWrapper); + AsyncInfoWrapper.finalize(Err); + return Err; +} + +static void dumpContigCopyInfo(const NonContigDescTy &CopyInfo) { + for (unsigned I = 0; I < CopyInfo.getRank(); I++) + ODBG(OLDT_Init) << " Dim " << I << " : Offset " << CopyInfo.Dims[I].Offset + << " Count " << CopyInfo.Dims[I].Count << " Stride " + << CopyInfo.Dims[I].Stride << "\n"; +} + +template <auto CopyFunc, typename DstPtrTy, typename SrcPtrTy> +static Error targetDataNonContiguous(GenericDeviceTy &Device, DstPtrTy DstPtr, + SrcPtrTy SrcPtr, + const NonContigDescTy &CopyInfo, + unsigned CurrentDim, uint64_t Offset, + AsyncInfoWrapperTy &AsyncInfoWrapper) { + ODBG(OLDT_Init) << "Non Contig Copy of Dim " << CurrentDim; + if (CurrentDim == CopyInfo.getRank()) { + ODBG(OLDT_Init) << "Moving non-contiguous chunk with size " + << CopyInfo.getLastDimCopySize() << ", " << (void *)SrcPtr + << " -> " << (void *)DstPtr << ".\n"; + return (Device.*CopyFunc)(DstPtr + Offset, SrcPtr + Offset, + CopyInfo.getLastDimCopySize(), AsyncInfoWrapper); + } + + for (unsigned int I = 0; I < CopyInfo.Dims[CurrentDim].Count; ++I) { + uint64_t CurOffset = + CopyInfo.Dims[CurrentDim].Offset + I * CopyInfo.Dims[CurrentDim].Stride; + // we only need to transfer the first element for the last dimension + // since we've already got a contiguous piece. 
+ if (CurrentDim != CopyInfo.getRank() - 1 || I == 0) { + Error Ret = targetDataNonContiguous<CopyFunc>( + Device, DstPtr, SrcPtr, CopyInfo, CurrentDim + 1, Offset + CurOffset, + AsyncInfoWrapper); + if (Ret) + return Ret; + } + } + + return Error::success(); +} + +Error GenericDeviceTy::dataNonContigSubmitImpl( + void *TgtPtr, const void *HstPtr, const NonContigDescTy &CopyInfo, + AsyncInfoWrapperTy &AsyncInfoWrapper) { + ODBG(OLDT_Init) << "Non contig descriptor:\n"; + dumpContigCopyInfo(CopyInfo); + NonContigDescTy MergedCopyInfo = CopyInfo; + MergedCopyInfo.mergeContiguousDims(); + ODBG(OLDT_Init) << "Merged non contig descriptor:\n"; + dumpContigCopyInfo(MergedCopyInfo); + return targetDataNonContiguous<&GenericDeviceTy::dataSubmitImpl>( + *this, reinterpret_cast<char *>(TgtPtr), + reinterpret_cast<const char *>(HstPtr), MergedCopyInfo, /*CurrentDim=*/0, + /*Offset=*/0, AsyncInfoWrapper); +} + Error GenericDeviceTy::dataRetrieve(void *HstPtr, const void *TgtPtr, int64_t Size, __tgt_async_info *AsyncInfo) { AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo); @@ -1111,6 +1185,32 @@ Error GenericDeviceTy::dataRetrieve(void *HstPtr, const void *TgtPtr, return Err; } +Error GenericDeviceTy::dataNonContigRetrieve(void *HstPtr, const void *TgtPtr, + const NonContigDescTy &CopyInfo, + __tgt_async_info *AsyncInfo) { + AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo); + + auto Err = + dataNonContigRetrieveImpl(HstPtr, TgtPtr, CopyInfo, AsyncInfoWrapper); + AsyncInfoWrapper.finalize(Err); + return Err; +} + +Error GenericDeviceTy::dataNonContigRetrieveImpl( + void *HstPtr, const void *TgtPtr, const NonContigDescTy &CopyInfo, + AsyncInfoWrapperTy &AsyncInfoWrapper) { + ODBG(OLDT_Init) << "Non contig descriptor:\n"; + dumpContigCopyInfo(CopyInfo); + NonContigDescTy MergedCopyInfo = CopyInfo; + MergedCopyInfo.mergeContiguousDims(); + ODBG(OLDT_Init) << "Merged non contig descriptor:\n"; + dumpContigCopyInfo(MergedCopyInfo); + return 
targetDataNonContiguous<&GenericDeviceTy::dataRetrieveImpl>( + *this, reinterpret_cast<char *>(HstPtr), + reinterpret_cast<const char *>(TgtPtr), MergedCopyInfo, /*CurrentDim=*/0, + /*Offset=*/0, AsyncInfoWrapper); +} + Error GenericDeviceTy::dataExchange(const void *SrcPtr, GenericDeviceTy &DstDev, void *DstPtr, int64_t Size, __tgt_async_info *AsyncInfo) { @@ -1628,6 +1728,22 @@ int32_t GenericPluginTy::data_submit_async(int32_t DeviceId, void *TgtPtr, return OFFLOAD_SUCCESS; } +int32_t GenericPluginTy::data_non_contig_submit_async( + int32_t DeviceId, void *TgtPtr, void *HstPtr, + const NonContigDescTy &CopyInfo, __tgt_async_info *AsyncInfoPtr) { + auto Err = getDevice(DeviceId).dataNonContigSubmit(TgtPtr, HstPtr, CopyInfo, + AsyncInfoPtr); + if (Err) { + REPORT() << "Failure to copy non-contiguous data from device to host." + << "Pointers: host " + << "= " << HstPtr << ", device = " << TgtPtr << ": " + << toString(std::move(Err)); + return OFFLOAD_FAIL; + } + + return OFFLOAD_SUCCESS; +} + int32_t GenericPluginTy::data_retrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr, int64_t Size) { return data_retrieve_async(DeviceId, HstPtr, TgtPtr, Size, @@ -1649,6 +1765,22 @@ int32_t GenericPluginTy::data_retrieve_async(int32_t DeviceId, void *HstPtr, return OFFLOAD_SUCCESS; } +int32_t GenericPluginTy::data_non_contig_retrieve_async( + int32_t DeviceId, void *HstPtr, void *TgtPtr, + const NonContigDescTy &CopyInfo, __tgt_async_info *AsyncInfoPtr) { + auto Err = getDevice(DeviceId).dataNonContigRetrieve(HstPtr, TgtPtr, CopyInfo, + AsyncInfoPtr); + if (Err) { + REPORT() << "Failure to copy non-contiguous data from device to host." 
+ << "Pointers: host " + << "= " << HstPtr << ", device = " << TgtPtr << ": " + << toString(std::move(Err)); + return OFFLOAD_FAIL; + } + + return OFFLOAD_SUCCESS; +} + int32_t GenericPluginTy::data_exchange(int32_t SrcDeviceId, void *SrcPtr, int32_t DstDeviceId, void *DstPtr, int64_t Size) { @@ -1691,6 +1823,21 @@ int32_t GenericPluginTy::launch_kernel(int32_t DeviceId, void *TgtEntryPtr, return OFFLOAD_SUCCESS; } +int32_t GenericPluginTy::enqueue_host_call(int32_t DeviceId, + void (*Callback)(void *), + void *UserData, + __tgt_async_info *AsyncInfoPtr) { + auto Err = + getDevice(DeviceId).enqueueHostCall(Callback, UserData, AsyncInfoPtr); + if (Err) { + REPORT() << "Failure to enqueue host call in device " << DeviceId << ": " + << toString(std::move(Err)); + return OFFLOAD_FAIL; + } + + return OFFLOAD_SUCCESS; +} + int32_t GenericPluginTy::synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr) { auto Err = getDevice(DeviceId).synchronize(AsyncInfoPtr); @@ -1703,6 +1850,18 @@ int32_t GenericPluginTy::synchronize(int32_t DeviceId, return OFFLOAD_SUCCESS; } +int32_t GenericPluginTy::synchronize_static(int32_t DeviceId, + __tgt_async_info *AsyncInfoPtr) { + auto Err = getDevice(DeviceId).synchronize(AsyncInfoPtr, false); + if (Err) { + REPORT() << "Failure to synchronize stream " << AsyncInfoPtr->Queue << ": " + << toString(std::move(Err)); + return OFFLOAD_FAIL; + } + + return OFFLOAD_SUCCESS; +} + int32_t GenericPluginTy::query_async(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr) { auto Err = getDevice(DeviceId).queryAsync(AsyncInfoPtr); @@ -1725,6 +1884,18 @@ InfoTreeNode GenericPluginTy::obtain_device_info(int32_t DeviceId) { return std::move(*InfoOrErr); } +int32_t GenericPluginTy::query_async_static(int32_t DeviceId, + __tgt_async_info *AsyncInfoPtr) { + auto Res = getDevice(DeviceId).queryAsyncStatic(AsyncInfoPtr); + if (!Res) { + REPORT() << "Failure to query stream " << AsyncInfoPtr->Queue << ": " + << toString(Res.takeError()); + return 
OFFLOAD_FAIL; + } + + return static_cast<int32_t>(*Res); +} + void GenericPluginTy::print_device_info(int32_t DeviceId) { if (auto Err = getDevice(DeviceId).printInfo()) REPORT() << "Failure to print device " << DeviceId diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp index 05fdcb032bd29..42e306361535f 100644 --- a/offload/plugins-nextgen/cuda/src/rtl.cpp +++ b/offload/plugins-nextgen/cuda/src/rtl.cpp @@ -780,6 +780,19 @@ struct CUDADeviceTy : public GenericDeviceTy { return Plugin::check(Res, "error in cuStreamQuery: %s"); } + Expected<QueueStatusTy> + queryAsyncStaticImpl(__tgt_async_info &AsyncInfo) override { + CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue); + CUresult Res = cuStreamQuery(Stream); + + if (Res == CUDA_ERROR_NOT_READY) + return QueueStatusTy::NOT_READY; + if (Res == CUDA_SUCCESS) + return QueueStatusTy::READY; + + return Plugin::check(Res, "error in cuStreamQuery: %s"); + } + Expected<void *> dataLockImpl(void *HstPtr, int64_t Size) override { // TODO: Register the buffer as CUDA host memory. return HstPtr; diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp index 3f2cc612e4df6..d4c7b95f2f5db 100644 --- a/offload/plugins-nextgen/host/src/rtl.cpp +++ b/offload/plugins-nextgen/host/src/rtl.cpp @@ -328,6 +328,13 @@ struct GenELF64DeviceTy : public GenericDeviceTy { return Plugin::success(); } + /// All functions are already synchronous. No need to do anything on this + /// query function. + Expected<QueueStatusTy> + queryAsyncStaticImpl(__tgt_async_info &AsyncInfo) override { + return QueueStatusTy::READY; + } + /// This plugin does not support interoperability Error initAsyncInfoImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) override { return Plugin::success(); _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
