This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0b4ecdbc4a [BUGFIX] Fix threadsafety and shutdown issues with
threaded_engine_perdevice (#21110)
0b4ecdbc4a is described below
commit 0b4ecdbc4a35e50e097c6e171fcf0d78d46df007
Author: Dick Carter <[email protected]>
AuthorDate: Tue Aug 2 02:59:57 2022 -0700
[BUGFIX] Fix threadsafety and shutdown issues with
threaded_engine_perdevice (#21110)
* Fix threadsafety and shutdown issues with threaded_engine_perdevice
* Fix lint
* Add MXNET_USE_CUDA compile guards
* Remove unneeded include
---
src/engine/threaded_engine_perdevice.cc | 32 +++++++++++++++++++++-----------
1 file changed, 21 insertions(+), 11 deletions(-)
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index 79e8eaa539..e4fe454639 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -28,6 +28,7 @@
#include <dmlc/concurrency.h>
#include <dmlc/thread_group.h>
+#include <mutex>
#include <memory>
#include "../initialize.h"
#include "./threaded_engine.h"
@@ -74,6 +75,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
gpu_copy_workers_.Clear();
cpu_normal_workers_.Clear();
cpu_priority_worker_.reset(nullptr);
+#if MXNET_USE_CUDA
+ streams_.clear();
+ cuda_event_pool_per_worker_.clear();
+#endif
}
void Stop() override {
@@ -278,6 +283,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
CHECK(block != nullptr);
mshadow::Stream<gpu>* stream = nullptr;
GPUAuxStream* aux_stream = nullptr;
+ CUDAEventPool* event_pool = nullptr;
do {
ThreadPool::SetReadyOnDestroy setReady(ready_event);
// allocate stream
@@ -288,18 +294,22 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
aux_stream = new GPUAuxStream(stream);
}
+ // With thread safety...
+ {
+ static std::mutex m;
+ std::lock_guard<std::mutex> lock(m);
+ // register stream
+ streams_.push_back(stream);
+ auto event_pool_it = cuda_event_pool_per_worker_.find(ctx.dev_id);
+ if (event_pool_it != cuda_event_pool_per_worker_.end()) {
+ event_pool = event_pool_it->second.get();
+ } else {
+ auto res =
+ cuda_event_pool_per_worker_.emplace(ctx.dev_id, std::make_unique<CUDAEventPool>(ctx));
+ event_pool = res.first->second.get();
+ }
+ }
} while (false);
- // register stream
- streams_.push_back(stream);
- CUDAEventPool* event_pool;
- auto event_pool_it = cuda_event_pool_per_worker_.find(ctx.dev_id);
- if (event_pool_it != cuda_event_pool_per_worker_.end()) {
- event_pool = event_pool_it->second.get();
- } else {
- auto res =
- cuda_event_pool_per_worker_.emplace(ctx.dev_id, std::make_unique<CUDAEventPool>(ctx));
- event_pool = res.first->second.get();
- }
// execute task
OprBlock* opr_block;
RunContext run_ctx{ctx, stream, aux_stream};