This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f0c28a0791 [RUNTIME][METAL] Fix multithreading access of metal runtime
(#16605)
f0c28a0791 is described below
commit f0c28a0791e15d39dbf380178e869b8f98bd4b37
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Feb 18 18:16:12 2024 -0500
[RUNTIME][METAL] Fix multithreading access of metal runtime (#16605)
This PR fixes a bug where the Metal runtime cannot be accessed from multiple
threads.
This is because the ThreadLocal entry initialization happens during global
workspace initialization, meaning that other threads that later try to use the
Metal runtime cannot have their thread-local entry correctly initialized.
This PR fixes the problem by always using a nullptr fallback and looking up the
default stream in the global workspace.
Co-authored-by: tqchen <[email protected]>
---
src/runtime/metal/metal_common.h | 25 ++++++++++++++--------
src/runtime/metal/metal_device_api.mm | 40 +++++++++++------------------------
src/runtime/metal/metal_module.mm | 9 ++++----
3 files changed, 32 insertions(+), 42 deletions(-)
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index dad156bcdd..d9154e0f79 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -136,10 +136,7 @@ class MetalWorkspace final : public DeviceAPI {
std::vector<id<MTLDevice>> devices;
// Warp size constant
std::vector<int> warp_size;
- // Whether it is initialized.
- bool initialized_{false};
- // the mutex for initialization
- std::mutex mutex;
+ MetalWorkspace();
// Destructor
~MetalWorkspace();
// Get device for given device
@@ -149,9 +146,6 @@ class MetalWorkspace final : public DeviceAPI {
<< "Invalid Metal device_id=" << dev.device_id;
return devices[dev.device_id];
}
- // Initialize workspace
- // Return false if already initialized, otherwise return true.
- void Init();
// override device API
void SetDevice(Device dev) final;
void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final;
@@ -163,7 +157,16 @@ class MetalWorkspace final : public DeviceAPI {
void SetStream(Device dev, TVMStreamHandle stream) final;
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;
- void ReinitializeStreams();
+ void ReinitializeDefaultStreams();
+
+ /**
+ * Cast stream to the right metal stream data structure
+ * if stream is nullptr , return the default stream of device_id
+ * \param stream the input stream handle
+ * \param device_id The device id of interest
+ * \returns The stream used in this function.
+ */
+ Stream* CastStreamOrGetDefault(TVMStreamHandle stream, int device_id);
// get the global workspace
static MetalWorkspace* Global();
@@ -184,7 +187,7 @@ class MetalThreadEntry {
/*! \brief The current device */
Device device;
/*! \brief The current stream */
- std::vector<Stream*> stream;
+ std::vector<TVMStreamHandle> stream;
/*! \brief The shared buffer used for copy. */
std::vector<id<MTLBuffer>> temp_buffer_;
/*! \brief workspace pool */
@@ -193,6 +196,10 @@ class MetalThreadEntry {
MetalThreadEntry() : pool(static_cast<DLDeviceType>(kDLMetal),
MetalWorkspace::Global()) {
device.device_id = 0;
device.device_type = static_cast<DLDeviceType>(kDLMetal);
+ MetalWorkspace* global_ws = MetalWorkspace::Global();
+ // by default, set the stream to nullptr, which indicate
+ // that we are using default stream
+ this->stream.resize(global_ws->devices.size(), nullptr);
}
~MetalThreadEntry();
// Get temp buffer with at least size under dev.
diff --git a/src/runtime/metal/metal_device_api.mm
b/src/runtime/metal/metal_device_api.mm
index c4ffc8943c..e3853ef6d6 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -42,7 +42,6 @@ MetalWorkspace* MetalWorkspace::Global() {
void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
{
AUTORELEASEPOOL {
- this->Init();
size_t index = static_cast<size_t>(dev.device_id);
if (kind == kExist) {
*rv = int(index < devices.size());
@@ -142,29 +141,18 @@ MetalWorkspace::~MetalWorkspace() {
}
}
-void MetalWorkspace::ReinitializeStreams() {
- std::vector<Stream*>& threadStreams =
MetalThreadEntry::ThreadLocal()->stream;
- ICHECK_EQ(default_streams_.size(), threadStreams.size());
+void MetalWorkspace::ReinitializeDefaultStreams() {
for (size_t i = 0; i < default_streams_.size(); ++i) {
- if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i])
- delete threadStreams[i];
delete default_streams_[i];
}
default_streams_.resize(devices.size());
- threadStreams.resize(devices.size());
for (size_t i = 0; i < devices.size(); ++i) {
Stream* stream = new Stream(devices[i]);
default_streams_[i] = stream;
- threadStreams[i] = stream;
}
}
-void MetalWorkspace::Init() {
- if (initialized_) return;
- std::lock_guard<std::mutex> lock(this->mutex);
- if (initialized_) return;
- initialized_ = true;
- if (devices.size() != 0) return;
+MetalWorkspace::MetalWorkspace() {
#if TARGET_OS_IPHONE
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
@@ -178,7 +166,7 @@ void MetalWorkspace::Init() {
warp_size.push_back(GetWarpSize(d));
}
#endif
- ReinitializeStreams();
+ this->ReinitializeDefaultStreams();
}
void MetalWorkspace::SetDevice(Device dev) {
@@ -189,7 +177,6 @@ void* MetalWorkspace::AllocDataSpace(Device device, size_t
nbytes, size_t alignm
DLDataType type_hint) {
id<MTLBuffer> buf;
AUTORELEASEPOOL {
- this->Init();
id<MTLDevice> dev = GetDevice(device);
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
@@ -220,20 +207,20 @@ void MetalWorkspace::FreeDataSpace(Device dev, void* ptr)
{
};
}
-Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) {
+Stream* MetalWorkspace::CastStreamOrGetDefault(TVMStreamHandle stream, int
device_id) {
if (stream != nullptr) return static_cast<Stream*>(stream);
- ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr);
- return MetalThreadEntry::ThreadLocal()->stream[device_id];
+ ICHECK_LT(static_cast<size_t>(device_id), default_streams_.size());
+ ICHECK(default_streams_[device_id] != nullptr);
+ return default_streams_[device_id];
}
void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset,
void* to,
size_t to_offset, size_t size, Device
dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle
stream) {
AUTORELEASEPOOL {
- this->Init();
Device dev = dev_from;
if (dev_from.device_type == kDLCPU) dev = dev_to;
- Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
+ Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id);
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to
current stream";
}
@@ -303,15 +290,12 @@ TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
ICHECK(stream != nullptr);
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " <<
dev.device_id;
- Stream* s = static_cast<Stream*>(stream);
- if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s)
- MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr;
- delete s;
+ delete static_cast<Stream*>(stream);
}
void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
AUTORELEASEPOOL {
- Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
+ Stream* s = CastStreamOrGetDefault(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
[cb commit];
@@ -325,7 +309,7 @@ void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle
stream) {
void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " <<
dev.device_id;
ICHECK(stream != nullptr);
- MetalThreadEntry::ThreadLocal()->stream[dev.device_id] =
static_cast<Stream*>(stream);
+ MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream;
}
void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType
type_hint) {
@@ -374,7 +358,7 @@ TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs
args, TVMRetValue* r
});
TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() {
- MetalWorkspace::Global()->ReinitializeStreams();
+ MetalWorkspace::Global()->ReinitializeDefaultStreams();
});
} // namespace metal
diff --git a/src/runtime/metal/metal_module.mm
b/src/runtime/metal/metal_module.mm
index 98e32cdf9c..01d1079426 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -191,7 +191,9 @@ class MetalWrappedFunc {
AUTORELEASEPOOL {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->device.device_id;
- auto stream = static_cast<metal::Stream*>(t->stream[device_id]);
+ // obtain the stream
+ auto stream =
+
metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id],
device_id);
if (stream->HasErrorHappened()) return;
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
@@ -265,10 +267,7 @@ Module MetalModuleCreate(std::unordered_map<std::string,
std::string> smap,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string fmt,
std::string source) {
ObjectPtr<Object> n;
- AUTORELEASEPOOL {
- metal::MetalWorkspace::Global()->Init();
- n = make_object<MetalModuleNode>(smap, fmap, fmt, source);
- };
+ AUTORELEASEPOOL { n = make_object<MetalModuleNode>(smap, fmap, fmt, source);
};
return Module(n);
}