This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0984e97a5c [Bugfix][NCCL] Release NCCL thread_local resources in destructor (#17078)
0984e97a5c is described below

commit 0984e97a5c799c7db961ffb2d427ee923eccb607
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Jun 12 18:44:51 2024 -0500

    [Bugfix][NCCL] Release NCCL thread_local resources in destructor (#17078)
    
    Prior to this commit, allocations performed by `ncclCommInitRank` had
    no corresponding call to `ncclCommDestroy`.  While `ncclCommDestroy`
    does occur in the `CCLThreadLocalContext::Clear` method, there are no
    calls into this method.  On worker processes, the failure to call
    `ncclCommDestroy` typically had little effect.  Any destruction would
    occur shortly before the process closes, and so resources would be
    reclaimed by the OS when the process terminates.
    
    However, worker0 of a Disco session is a separate thread, rather than
    a separate process.  While this allows it to easily receive data from
    the controller thread, resources allocated by worker0 are not
    reclaimed by the OS until the entire process terminates.  As a result,
    the `CCLThreadLocalContext` leaked GPU memory, as the communicator
    allocated by the `ncclCommInitRank` call at the start of each
    `tvm.runtime.disco.ProcessSession` was never de-allocated.  The
    increase in GPU memory usage was about 1 gigabyte for each
    `ProcessSession`.
    
    This commit updates `CCLThreadLocalContext` to have a destructor that
    calls the `Clear` method.  For worker0, the destructor is called when
    the worker thread is joined to the main thread.
---
 src/runtime/disco/nccl/nccl.cc        | 12 ++++++++++++
 src/runtime/disco/nccl/nccl_context.h | 15 +++++++++++----
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 7b943cf83f..bba42ed3bd 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -67,9 +67,21 @@ void InitCCLPerWorker(IntTuple device_ids, std::string 
unique_id_bytes) {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   DiscoWorker* worker = DiscoWorker::ThreadLocal();
   ICHECK(worker != nullptr);
+
   CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES)
       << "ValueError: The length of unique_id must be " << 
NCCL_UNIQUE_ID_BYTES << ", but got "
       << unique_id_bytes.size() << ".";
+
+  CHECK(!ctx->comm) << "Cannot initialize CCL, "
+                    << "the previous thread-global comm still exists, "
+                    << "and has not been destructed";
+  CHECK(!ctx->default_stream) << "Cannot initialize CCL, "
+                              << "the previous thread-global stream still 
exists, "
+                              << "and has not been destructed";
+  CHECK(!ctx->worker) << "Cannot initialize CCL, "
+                      << "the previous thread-global worker still exists, "
+                      << "and has not been destructed";
+
   // Step up local context of NCCL
   int device_id = device_ids[worker->worker_id];
   SetDevice(device_id);
diff --git a/src/runtime/disco/nccl/nccl_context.h 
b/src/runtime/disco/nccl/nccl_context.h
index 9d1b8b933a..3fb281f2cb 100644
--- a/src/runtime/disco/nccl/nccl_context.h
+++ b/src/runtime/disco/nccl/nccl_context.h
@@ -118,16 +118,23 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType 
dtype) {
 }
 
 struct CCLThreadLocalContext {
-  DiscoWorker* worker;
+  DiscoWorker* worker = nullptr;
   int device_id;
   deviceStream_t default_stream = nullptr;
-  ncclComm_t comm;
+  ncclComm_t comm = nullptr;
+
+  ~CCLThreadLocalContext() { Clear(); }
 
   void Clear() {
-    NCCL_CALL(ncclCommDestroy(comm));
-    if (default_stream != nullptr) {
+    if (comm) {
+      NCCL_CALL(ncclCommDestroy(comm));
+      comm = nullptr;
+    }
+    if (default_stream) {
       StreamDestroy(default_stream);
+      default_stream = nullptr;
     }
+    worker = nullptr;
   }
 
   deviceStream_t GetDefaultStream() {

Reply via email to