ptrendx commented on a change in pull request #20331:
URL: https://github.com/apache/incubator-mxnet/pull/20331#discussion_r701324371
##########
File path: src/engine/threaded_engine.cc
##########
@@ -523,5 +543,206 @@ void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
OprBlock::Delete(opr_block);
}
+void ThreadedEngine::OnStartStatic(Engine *engine, void *opr_block,
+ const dmlc::Error* error) {
+ // no-op
+}
+
+#if MXNET_USE_CUDA
+static inline void AddEventHelper(
+ std::unordered_map<cudaStream_t, EventInfo>* events_per_stream,
+ const EventInfo& cuda_event) {
+ auto event_stream = cuda_event.stream;
+ if (events_per_stream->count(event_stream) > 0) {
+    if ((*events_per_stream)[event_stream].pool_index < cuda_event.pool_index) {
+ (*events_per_stream)[event_stream] = cuda_event;
+ }
+ } else {
+ (*events_per_stream).emplace(event_stream, cuda_event);
+ }
+}
+
+void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block,
+ const dmlc::Error* error) {
+  static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false);
+ if (!use_new_dep_engine) {
+ return;
+ }
+ ThreadedOpr *threaded_opr = static_cast<OprBlock*>(opr_block)->opr;
+ std::unordered_map<cudaStream_t, EventInfo> event_per_stream;
+ for (auto* read_var : threaded_opr->const_vars) {
+ auto &sync_obj = read_var->sync_object;
+ std::lock_guard<std::mutex> l(sync_obj.mutex);
+ auto &reader_events = sync_obj.reader_events;
+ // check for expired events and delete them
+    reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(),
+ [&](const EventInfo e_i) {
+ return e_i.event.expired();
+ }), reader_events.end());
+ for (auto& cuda_event : reader_events) {
+ AddEventHelper(&event_per_stream, cuda_event);
+ }
+ if (!sync_obj.writer_event.empty()) {
+ if (sync_obj.writer_event[0].event.expired()) {
+ sync_obj.writer_event.clear();
+ } else {
+ AddEventHelper(&event_per_stream, sync_obj.writer_event[0]);
+ }
+ }
+ }
+
+ for (auto* write_var : threaded_opr->mutable_vars) {
+ auto &sync_obj = write_var->sync_object;
+ std::lock_guard<std::mutex> l(sync_obj.mutex);
+ auto &reader_events = sync_obj.reader_events;
+ // check for expired events and delete them
+    reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(),
+ [&](const EventInfo e_i) {
+ return e_i.event.expired();
+ }), reader_events.end());
+ for (auto& cuda_event : reader_events) {
+ AddEventHelper(&event_per_stream, cuda_event);
+ }
+ if (!sync_obj.writer_event.empty()) {
+ if (sync_obj.writer_event[0].event.expired()) {
+ sync_obj.writer_event.clear();
+ } else {
+ AddEventHelper(&event_per_stream, sync_obj.writer_event[0]);
+ }
+ }
+ }
+ for (auto event : event_per_stream) {
+ auto ev = event.second.event.lock();
+ MSHADOW_CUDA_CALL(cudaEventSynchronize(*ev));
+ }
+}
+
+void ThreadedEngine::OnStartGPU(Engine *engine, void *sync_info,
+ const dmlc::Error* error) {
+  static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false);
+ if (!use_new_dep_engine) {
+ return;
+ }
+ auto *info = reinterpret_cast<GPUWorkerSyncInfo *>(sync_info);
+ CHECK(info->stream != nullptr);
+ auto *worker_stream = reinterpret_cast<mshadow::Stream<gpu> *>(info->stream);
+ ThreadedOpr *threaded_opr = static_cast<OprBlock*>(info->opr_block)->opr;
+ std::unordered_map<cudaStream_t, EventInfo> event_per_stream;
+ for (auto* read_var : threaded_opr->const_vars) {
+ auto &sync_obj = read_var->sync_object;
+ std::lock_guard<std::mutex> l(sync_obj.mutex);
+ auto &reader_events = sync_obj.reader_events;
+ // check for expired events and delete them
+    reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(),
+ [&](const EventInfo e_i) {
+ return e_i.event.expired();
+ }), reader_events.end());
+ for (auto& writer : sync_obj.writer_event) {
+ if (writer.event.expired()) {
+ sync_obj.writer_event.clear();
+ break;
+ }
+ if (writer.stream != worker_stream->stream_) {
+ // if there is already a reader on the same stream as us,
+ // it already synced with that writer and we can rely on
+ // the ongoing sync
+ bool found = false;
+ for (const auto& reader : reader_events) {
+ if (reader.stream == worker_stream->stream_) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ AddEventHelper(&event_per_stream,
+ writer);
+ }
+ }
+ }
+ }
+ for (auto* write_var : threaded_opr->mutable_vars) {
+ auto &sync_obj = write_var->sync_object;
+ std::lock_guard<std::mutex> l(sync_obj.mutex);
+ // check for expired events and delete them
+ auto &reader_events = sync_obj.reader_events;
+    reader_events.erase(std::remove_if(reader_events.begin(), reader_events.end(),
+ [&](const EventInfo e_i) {
+ return e_i.event.expired();
+ }), reader_events.end());
+ // if there are some readers, we wait for them
+ for (auto& cuda_event : reader_events) {
+ if (worker_stream->stream_ != cuda_event.stream) {
+ AddEventHelper(&event_per_stream, cuda_event);
+ }
+ }
+ if (!sync_obj.writer_event.empty()) {
+ if (sync_obj.writer_event[0].event.expired()) {
+ sync_obj.writer_event.clear();
Review comment:
There can only be 1 writer event active (vs multiple readers).
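For reference, a minimal sketch of that invariant (illustrative only: `SyncObject`, `EventInfo`, and the `DropExpiredWriter` helper below are simplified stand-ins, not the PR's actual definitions). Each variable tracks many reader events but at most one writer event, which is why the code above only ever inspects and clears `writer_event[0]`.

```c++
#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

// Simplified stand-in for the per-variable sync bookkeeping used in the diff.
// The real EventInfo stores a cudaStream_t and a weak_ptr to a pooled CUDA event.
struct EventInfo {
  void* stream = nullptr;
  std::weak_ptr<int> event;
  uint64_t pool_index = 0;
};

struct SyncObject {
  std::mutex mutex;
  std::vector<EventInfo> reader_events;  // many concurrent readers allowed
  std::vector<EventInfo> writer_event;   // at most one active writer: size 0 or 1
};

// The check-and-clear pattern from OnStartCPU/OnStartGPU in isolation:
// an expired weak_ptr means the underlying event is gone, so there is
// nothing left to wait on and the stale entry can simply be dropped.
inline void DropExpiredWriter(SyncObject* sync_obj) {
  std::lock_guard<std::mutex> l(sync_obj->mutex);
  if (!sync_obj->writer_event.empty() &&
      sync_obj->writer_event[0].event.expired()) {
    sync_obj->writer_event.clear();
  }
}
```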