This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 40b9a926b9 [Disco] Pipe-based Multi-processing Session (#15727)
40b9a926b9 is described below

commit 40b9a926b9df375a6ca5ce51ea93e4d849b85517
Author: Junru Shao <[email protected]>
AuthorDate: Thu Sep 14 04:43:44 2023 -0700

    [Disco] Pipe-based Multi-processing Session (#15727)
    
    This PR introduces `ProcessSession`, a new session implementation based
    on multi-processing.
    
    `ProcessSession` shares exactly the same communication protocol with
    `ThreadedSession`, but every worker except worker 0 is launched in a
    separate process rather than a thread. Workers communicate with the
    controller via pipes provided by the OS, rather than via SPSC message
    queues between threads.
    
    In our implementation, Python's `subprocess.Popen` is used to create
    subprocesses, and the Python executable (more specifically,
    `sys.executable`) runs `tvm.exec.disco_worker` as the entrypoint.
    Besides the launching logic that is only executed once in the very
    beginning, the rest of the implementation resides in a C++-only
    environment, including reads/writes to pipe file descriptors,
    serialization and deserialization of messages, worker interpretation of
    each message, etc.
    
    Detailed engineering elements included in this PR:
    - Refactors the MinRPC-based communication protocol out to be shared by
      `ProcessSession` and `ThreadedSession` as `protocol.h`;
    - Refactors a controller-side worker thread into `DiscoWorkerThread`,
      which is shared by both session implementations to launch worker-0;
    - Added two instructions `kDebugGetFromRemote` and `kDebugSetRegister`,
      which are used to communicate with workers other than worker-0 in
      debug mode;
    - Introduces multi-processing infra including: `tvm.exec.disco_worker`
      serving as the entrypoint that launches workers, and
      `tvm/runtime/disco/process_pool.py` that exposes APIs to launch worker
      processes. `tvm.exec.disco_worker` calls into a global function
      `runtime.disco.WorkerProcess` that executes the worker main loop in
      pure C++;
    - Introduces `src/support/process_id.h` that provides cross-platform pid
      and tid printing utilities;
    - Refactors Disco's NCCL integration to get rid of the initialized-once
      global NCCL context, switching to broadcasting `ncclUniqueId` from the
      controller to all workers and then creating NCCL communicators in each
      worker thread/process accordingly. This is a thread/process-agnostic
      way of using NCCL.
---
 include/tvm/runtime/disco/session.h      |  50 +++++-
 python/tvm/exec/disco_worker.py          |  51 +++++++
 python/tvm/runtime/disco/__init__.py     |   9 +-
 python/tvm/runtime/disco/process_pool.py | 180 ++++++++++++++++++++++
 python/tvm/runtime/disco/session.py      |  27 +++-
 python/tvm/testing/__init__.py           |   3 +-
 python/tvm/testing/disco.py              |  53 +++++++
 src/runtime/disco/bcast_session.cc       |   7 +
 src/runtime/disco/bcast_session.h        |   2 +
 src/runtime/disco/builtin.cc             |   6 +-
 src/runtime/disco/nccl/nccl.cc           |  91 ++++-------
 src/runtime/disco/nccl/utils.h           |   2 +
 src/runtime/disco/process_session.cc     | 213 ++++++++++++++++++++++++++
 src/runtime/disco/protocol.h             | 254 +++++++++++++++++++++++++++++++
 src/runtime/disco/session.cc             |  11 +-
 src/runtime/disco/threaded_session.cc    | 128 ++++------------
 src/runtime/disco/worker.cc              |  55 ++++++-
 src/runtime/disco/worker.h               |  43 ++++++
 src/support/process_id.h                 |  67 ++++++++
 tests/python/disco/test_nccl.py          |  92 ++++++-----
 tests/python/disco/test_session.py       |  95 +++++-------
 21 files changed, 1149 insertions(+), 290 deletions(-)

diff --git a/include/tvm/runtime/disco/session.h 
b/include/tvm/runtime/disco/session.h
index e28fb7144c..984ea026d8 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -72,6 +72,7 @@
 #ifndef TVM_RUNTIME_DISCO_SESSION_H_
 #define TVM_RUNTIME_DISCO_SESSION_H_
 
+#include <tvm/runtime/container/shape_tuple.h>
 #include <tvm/runtime/object.h>
 #include <tvm/runtime/packed_func.h>
 
@@ -92,6 +93,8 @@ enum class DiscoAction : int32_t {
   kSyncWorker = 4,
   kCopyFromWorker0 = 5,
   kCopyToWorker0 = 6,
+  kDebugGetFromRemote = 7,
+  kDebugSetRegister = 8,
 };
 
 /*! \brief Converts the enum class `DiscoAction` to string */
@@ -111,6 +114,10 @@ inline std::string DiscoAction2String(DiscoAction action) {
       return "kCopyFromWorker0";
     case DiscoAction::kCopyToWorker0:
       return "kCopyToWorker0";
+    case DiscoAction::kDebugGetFromRemote:
+      return "kDebugGetFromRemote";
+    case DiscoAction::kDebugSetRegister:
+      return "kDebugSetRegister";
   }
   LOG(FATAL) << "ValueError: Unknown DiscoAction: " << 
static_cast<int>(action);
 }
@@ -136,7 +143,7 @@ class DRefObj : public Object {
    * \param worker_id The id of the worker to be copied to.
    * \param source The NDArray to be copied.
    */
-  void DebugCopyFrom(int worker_id, NDArray source);
+  inline void DebugCopyFrom(int worker_id, TVMArgValue source);
 
   static constexpr const char* _type_key = "runtime.disco.DRef";
   static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef;
@@ -213,6 +220,12 @@ class SessionObj : public Object {
   virtual void SyncWorker(int worker_id) = 0;
   /*! \brief Signal all the workers to shutdown */
   virtual void Shutdown() = 0;
+  /*!
+   * \brief Initialize the data plane between workers.
+   * \param ccl The name of the communication backend, e.g., nccl, rccl, mpi.
+   * \param device_ids The device ids of the workers.
+   */
+  virtual void InitCCL(String ccl, ShapeTuple device_ids) = 0;
   /*!
    * \brief Get the value of a register from a remote worker.
    * \param reg_id The id of the register to be fetched.
@@ -220,13 +233,19 @@ class SessionObj : public Object {
    * \return The value of the register.
    */
   virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
-
-  static constexpr const char* _type_key = "runtime.disco.Session";
-  TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object);
+  /*!
+   * \brief Set the value of a register on a remote worker.
+   * \param reg_id The id of the register to be set.
+   * \param value The value to be set.
+   * \param worker_id The id of the worker to be set.
+   */
+  virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int 
worker_id) = 0;
 
   struct FFI;
   friend struct SessionObj::FFI;
   friend class DRefObj;
+  static constexpr const char* _type_key = "runtime.disco.Session";
+  TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object);
 
  protected:
   /*! \brief Deallocate a register id, kill it on all workers, and append it 
to `free_regs_`. */
@@ -239,8 +258,22 @@ class SessionObj : public Object {
  */
 class Session : public ObjectRef {
  public:
-  /*! \brief Create a session backed by a thread pool of workers */
-  static Session ThreadedSession(int num_workers);
+  /*!
+   * \brief Create a session backed by a thread pool of workers
+   * \param num_workers The number of workers.
+   */
+  TVM_DLL static Session ThreadedSession(int num_workers);
+  /*!
+   * \brief Create a session backed by pipe-based multiprocessing
+   * \param num_workers The number of workers.
+   * \param process_pool_creator The name of a global function that takes 
`num_workers` as an input,
+   * and returns a PackedFunc, which takes an integer `worker_id` as the input 
and returns None.
+   * When `worker_id` is 0, it shuts down the process pool; Otherwise, it 
returns a tuple
+   * (read_fd, write_fd) used to communicate with the corresponding worker.
+   * \note Worker-0 is always co-located with the controller as a separate 
thread, and therefore
+   * worker-0 does not exist in the process pool.
+   */
+  TVM_DLL static Session ProcessSession(int num_workers, String 
process_pool_creator);
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, 
SessionObj);
 };
 
@@ -250,6 +283,7 @@ class Session : public ObjectRef {
  */
 class DiscoChannel {
  public:
+  virtual ~DiscoChannel() = default;
   /*! \brief Send a packed sequence to the receiver */
   virtual void Send(const TVMArgs& args) = 0;
   /*! \brief Receive a packed sequence from worker */
@@ -272,6 +306,10 @@ TVMRetValue DRefObj::DebugGetFromRemote(int worker_id) {
   return Downcast<Session>(this->session)->DebugGetFromRemote(this->reg_id, 
worker_id);
 }
 
+void DRefObj::DebugCopyFrom(int worker_id, TVMArgValue value) {
+  return Downcast<Session>(this->session)->DebugSetRegister(this->reg_id, 
value, worker_id);
+}
+
 template <typename... Args>
 DRef SessionObj::CallPacked(const DRef& func, Args&&... args) {
   constexpr int offset = 3;
diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
new file mode 100644
index 0000000000..9faa5742ae
--- /dev/null
+++ b/python/tvm/exec/disco_worker.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Internal DiscoWorker for Disco ProcessSession."""
+import os
+import sys
+
+from tvm import runtime as _  # pylint: disable=unused-import
+from tvm._ffi import get_global_func
+from tvm.testing import disco as _  # pylint: disable=unused-import
+
+
+def main():
+    """Main worker function"""
+    if len(sys.argv) != 5:
+        print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>")
+        return
+    worker_id = int(sys.argv[1])
+    num_workers = int(sys.argv[2])
+    if sys.platform == "win32":
+        import msvcrt  # pylint: disable=import-outside-toplevel,import-error
+
+        reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY)
+        writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
+    else:
+        reader = int(sys.argv[3])
+        writer = int(sys.argv[4])
+
+    worker_func = get_global_func("runtime.disco.WorkerProcess")
+    worker_func(worker_id, num_workers, reader, writer)
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except (KeyboardInterrupt, IOError):
+        pass
diff --git a/python/tvm/runtime/disco/__init__.py 
b/python/tvm/runtime/disco/__init__.py
index 57c0548e2e..856e69bc35 100644
--- a/python/tvm/runtime/disco/__init__.py
+++ b/python/tvm/runtime/disco/__init__.py
@@ -15,4 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM distributed runtime API."""
-from .session import DModule, DPackedFunc, DRef, Session, ThreadedSession
+from .session import (
+    DModule,
+    DPackedFunc,
+    DRef,
+    ProcessSession,
+    Session,
+    ThreadedSession,
+)
diff --git a/python/tvm/runtime/disco/process_pool.py 
b/python/tvm/runtime/disco/process_pool.py
new file mode 100644
index 0000000000..44348577f7
--- /dev/null
+++ b/python/tvm/runtime/disco/process_pool.py
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Pipe worker for multi-processing."""
+import os
+import subprocess
+import sys
+
+import psutil
+
+from tvm._ffi import register_func
+from tvm.runtime import ShapeTuple
+
+
+class DiscoPopenWorker:
+    """A subprocess worker via Popen.
+
+    PopenWorker provides a low-level
+    API to interact with a separate process via Popen.
+
+    Parameters
+    ----------
+    worker_id : int
+        The worker id of the current worker.
+
+    num_workers : int
+        The total number of workers.
+
+    stdout: Union[None, int, IO[Any]]
+        The standard output streams handler specified for the popen process.
+
+    stderr: Union[None, int, IO[Any]]
+        The standard error streams handler specified for the popen process.
+    """
+
+    def __init__(self, worker_id: int, num_workers: int, stdout=None, 
stderr=None):
+        self.worker_id = worker_id
+        self.num_workers = num_workers
+        self._proc = None
+        self._stdout = stdout
+        self._stderr = stderr
+
+    def __del__(self):
+        try:
+            self.kill()
+        except ImportError:
+            pass
+
+    def kill(self):
+        """Kill the current running process and cleanup.
+
+        Note
+        ----
+        The worker can start a new process when start is called again.
+        """
+        if self._proc is not None:
+            # kill all child processes recursively
+            try:
+                _kill_child_processes(self._proc.pid)
+            except TypeError:
+                pass
+            try:
+                self._proc.kill()
+            except OSError:
+                pass
+
+            # Join the child process to avoid zombie processes
+            self.join(timeout=1.0)
+            self._proc = None
+
+    def join(self, timeout=None):
+        """Join the current process worker before it terminates.
+
+        Parameters
+        ----------
+        timeout: Optional[number]
+            Timeout value, block at most timeout seconds if it
+            is a positive number.
+        """
+        if self._proc:
+            try:
+                self._proc.wait(timeout)
+            except subprocess.TimeoutExpired:
+                pass
+
+    def start(self):
+        """Start a new subprocess if nothing is available"""
+        if self._proc is not None:
+            return None, None
+
+        # connect subprocess with a pair of pipes
+        main_read, worker_write = os.pipe()
+        worker_read, main_write = os.pipe()
+
+        cmd = [
+            sys.executable,
+            "-m",
+            "tvm.exec.disco_worker",
+            str(self.worker_id),
+            str(self.num_workers),
+        ]
+        if sys.platform == "win32":
+            import msvcrt  # pylint: 
disable=import-error,import-outside-toplevel
+
+            worker_read_handle = msvcrt.get_osfhandle(worker_read)
+            worker_write_handle = msvcrt.get_osfhandle(worker_write)
+            os.set_handle_inheritable(worker_read_handle, True)
+            os.set_handle_inheritable(worker_write_handle, True)
+            cmd += [str(worker_read_handle), str(worker_write_handle)]
+            self._proc = subprocess.Popen(
+                cmd,
+                close_fds=False,
+                stdout=self._stdout,
+                stderr=self._stderr,
+            )
+        else:
+            cmd += [str(worker_read), str(worker_write)]
+            self._proc = subprocess.Popen(  # pylint: 
disable=consider-using-with
+                cmd,
+                pass_fds=(worker_read, worker_write),
+                stdout=self._stdout,
+                stderr=self._stderr,
+            )
+
+        # close worker side of the pipe
+        os.close(worker_read)
+        os.close(worker_write)
+        return main_read, main_write
+
+
+def _kill_child_processes(pid):
+    """Kill all child processes recursively for a given pid.
+
+    Parameters
+    ----------
+    pid : int
+        The id of the process whose child processes are killed.
+    """
+    try:
+        parent = psutil.Process(pid)
+        children = parent.children(recursive=True)
+    except psutil.NoSuchProcess:
+        return
+
+    for process in children:
+        try:
+            process.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+
+@register_func("runtime.disco.create_process_pool")
+def _create_process_pool(num_workers: int):
+    """Create a process pool where the workers' ids are in [1, num_workers)."""
+    pool = [DiscoPopenWorker(i, num_workers) for i in range(1, num_workers)]
+
+    def result_func(worker_id: int):
+        nonlocal pool
+        if worker_id != 0:
+            read_fd, write_fd = pool[worker_id - 1].start()
+            return ShapeTuple([read_fd, write_fd])
+        print("Shutting down the process pool")
+        del pool
+        return None
+
+    return result_func
diff --git a/python/tvm/runtime/disco/session.py 
b/python/tvm/runtime/disco/session.py
index eab5a5268d..d05561c2d1 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -27,7 +27,7 @@ from ..container import ShapeTuple
 from ..ndarray import NDArray
 from ..ndarray import array as _as_NDArray
 from ..object import Object
-from . import _ffi_api
+from . import _ffi_api, process_pool  # pylint: disable=unused-import
 
 
 @register_object("runtime.disco.DRef")
@@ -250,22 +250,21 @@ class Session(Object):
         func = self._get_cached_method("runtime.disco.load_vm_module")
         return DModule(func(path, device))
 
-    def init_ccl(self, api: str, *args):
+    def init_ccl(self, ccl: str, *device_ids):
         """Initialize the underlying communication collective library.
 
         Parameters
         ----------
-        api : str
+        ccl : str
             The name of the communication collective library. Currently 
supported libraries are:
             - nccl
             - rccl
             - mpi
-        *args : various types
-            The arguments to be passed to the initialization function of the 
communication
+        *device_ids : int
+            The device IDs to be used by the underlying communication library.
         """
-        assert api in ("nccl", "rccl"), f"Unsupported CCL backend: {api}"
-        func = self.get_global_func(f"runtime.disco.{api}.init_ccl")
-        func(*args)
+        assert ccl in ("nccl", "rccl"), f"Unsupported CCL backend: {ccl}"
+        return _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids))  # 
type: ignore # pylint: disable=no-member
 
     def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
         """Broadcast an array from worker-0 to all other workers.
@@ -343,6 +342,18 @@ class ThreadedSession(Session):
         )
 
 
+@register_object("runtime.disco.ProcessSession")
+class ProcessSession(Session):
+    """A Disco session backed by pipe-based multi-processing."""
+
+    def __init__(self, num_workers: int) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.SessionProcess,  # type: ignore # pylint: 
disable=no-member
+            num_workers,
+            "runtime.disco.create_process_pool",
+        )
+
+
 REDUCE_OPS = {
     "sum": 0,
     "prod": 1,
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index 3e5f838a27..9aa1a31933 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -17,8 +17,7 @@
 
 # pylint: disable=redefined-builtin, wildcard-import
 """Utility Python functions for TVM testing"""
-
-from . import auto_scheduler, autotvm
+from . import auto_scheduler, autotvm, disco
 from ._ffi_api import (
     ErrorTest,
     FrontendTestModule,
diff --git a/python/tvm/testing/disco.py b/python/tvm/testing/disco.py
new file mode 100644
index 0000000000..c13e83b7c4
--- /dev/null
+++ b/python/tvm/testing/disco.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, missing-function-docstring, 
missing-class-docstring
+"""Common utilities for testing disco"""
+from tvm._ffi import register_func
+from tvm.runtime import NDArray, ShapeTuple, String
+from tvm.runtime.ndarray import array
+
+
+@register_func("tests.disco.add_one")
+def add_one(x: int) -> int:  # pylint: disable=invalid-name
+    return x + 1
+
+
+@register_func("tests.disco.add_one_float", override=True)
+def add_one_float(x: float):  # pylint: disable=invalid-name
+    return x + 0.5
+
+
+@register_func("tests.disco.add_one_ndarray", override=True)
+def add_one_ndarray(x: NDArray) -> NDArray:  # pylint: disable=invalid-name
+    return array(x.numpy() + 1)
+
+
+@register_func("tests.disco.str", override=True)
+def str_func(x: str):  # pylint: disable=invalid-name
+    return x + "_suffix"
+
+
+@register_func("tests.disco.str_obj", override=True)
+def str_obj_func(x: String):  # pylint: disable=invalid-name
+    assert isinstance(x, String)
+    return String(x + "_suffix")
+
+
+@register_func("tests.disco.shape_tuple", override=True)
+def shape_tuple_func(x: ShapeTuple):  # pylint: disable=invalid-name
+    assert isinstance(x, ShapeTuple)
+    return ShapeTuple(list(x) + [4, 5])
diff --git a/src/runtime/disco/bcast_session.cc 
b/src/runtime/disco/bcast_session.cc
index 0625c1157e..9b553c319c 100644
--- a/src/runtime/disco/bcast_session.cc
+++ b/src/runtime/disco/bcast_session.cc
@@ -68,6 +68,13 @@ void BcastSessionObj::Shutdown() {
   BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kShutDown, 
0);
 }
 
+void BcastSessionObj::InitCCL(String ccl, ShapeTuple device_ids) {
+  const auto* pf = runtime::Registry::Get("runtime.disco." + ccl + 
".init_ccl");
+  CHECK(pf) << "ValueError: Cannot initialize CCL `" << ccl
+            << "`, because cannot find function: runtime.disco." << ccl << 
".init_ccl";
+  (*pf)(GetRef<Session>(this), device_ids);
+}
+
 void BcastSessionObj::SyncWorker(int worker_id) {
   BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kSyncWorker, 
worker_id);
   TVMArgs args = this->RecvReplyPacked(worker_id);
diff --git a/src/runtime/disco/bcast_session.h 
b/src/runtime/disco/bcast_session.h
index 0221207b96..d064b30f5a 100644
--- a/src/runtime/disco/bcast_session.h
+++ b/src/runtime/disco/bcast_session.h
@@ -42,7 +42,9 @@ class BcastSessionObj : public SessionObj {
   void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) 
override;
   void SyncWorker(int worker_id) override;
   void Shutdown() override;
+  void InitCCL(String ccl, ShapeTuple device_ids) override;
   TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0;
+  void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) 
override = 0;
 
  protected:
   /*! \brief Deallocate a register id, kill it on all workers, and append it 
to `free_regs_`. */
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 64e3fd4b28..06408c723a 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -100,7 +100,11 @@ void RecvFromWorker0(NDArray buffer) { 
GetCCLFunc("recv_from_worker0")(buffer);
 
 int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
 
-void SyncWorker() { GetCCLFunc("sync_worker")(); }
+void SyncWorker() {
+  if (DiscoWorker::ThreadLocal()->ccl != "") {
+    GetCCLFunc("sync_worker")();
+  }
+}
 
 
TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
 TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body_typed(DiscoEmptyNDArray);
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 0212923cef..e404e3c2bb 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -19,13 +19,16 @@
 #include <cuda_runtime_api.h>
 #include <dlpack/dlpack.h>
 #include <nccl.h>
+#include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/disco/session.h>
 #include <tvm/runtime/registry.h>
 
+#include <cstring>
 #include <mutex>
 #include <sstream>
 #include <vector>
 
+#include "../../../support/process_id.h"
 #include "../../cuda/cuda_common.h"
 #include "./utils.h"
 
@@ -33,48 +36,6 @@ namespace tvm {
 namespace runtime {
 namespace nccl {
 
-struct NCCLGlobalContext {
-  std::vector<ncclComm_t> communicators;
-
-  static NCCLGlobalContext* Get() {
-    static NCCLGlobalContext ctx;
-    return &ctx;
-  }
-
-  void Initialize(const std::vector<int>& device_ids) {
-    {
-      std::ostringstream os;
-      bool is_first = true;
-      for (int device_id : device_ids) {
-        if (!is_first) {
-          os << ",";
-        } else {
-          is_first = false;
-        }
-        os << device_id;
-      }
-      LOG(INFO) << "Initializing NCCL with devices: " << os.str() << ".";
-    }
-    // TODO(@junrushao): support more flexible communicator pattern for 
generic SPMD usecases
-    DiscoWorker* worker = DiscoWorker::ThreadLocal();
-    int num_workers = worker->num_workers;
-    CHECK_EQ(device_ids.size(), num_workers)
-        << "ValueError: There are " << num_workers << " worker(s), but " << 
device_ids.size()
-        << " device id(s) are provided.";
-    ncclUniqueId id;
-    NCCL_CALL(ncclGetUniqueId(&id));
-    NCCL_CALL(ncclGroupStart());
-    for (int worker_id = 0; worker_id < num_workers; ++worker_id) {
-      int device_id = device_ids[worker_id];
-      ncclComm_t comm;
-      CUDA_CALL(cudaSetDevice(device_id));
-      NCCL_CALL(ncclCommInitRank(&comm, num_workers, id, worker_id));
-      this->communicators.push_back(comm);
-    }
-    NCCL_CALL(ncclGroupEnd());
-  }
-};
-
 struct NCCLThreadLocalContext {
   DiscoWorker* worker;
   int device_id;
@@ -92,23 +53,38 @@ struct NCCLThreadLocalContext {
   }
 };
 
-void InitCCL(const std::vector<int>& device_ids) {
-  // Set up global context only once
-  static std::once_flag flag;
-  std::call_once(flag, [&]() { 
NCCLGlobalContext::Get()->Initialize(device_ids); });
-  // Set up thread-local context for each thread
-  DiscoWorker* worker = DiscoWorker::ThreadLocal();
+void InitCCL(Session sess, ShapeTuple device_ids) {
+  DRef func = sess->GetGlobalFunc("runtime.disco.nccl.init_ccl_per_worker");
+  LOG(INFO) << "Initializing NCCL with devices: " << device_ids;
+  ncclUniqueId id;
+  TVMByteArray array;
+  NCCL_CALL(ncclGetUniqueId(&id));
+  array.data = id.internal;
+  array.size = NCCL_UNIQUE_ID_BYTES;
+  sess->CallPacked(func, device_ids, array);
+}
+
+void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
   NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  DiscoWorker* worker = DiscoWorker::ThreadLocal();
+  ICHECK(worker != nullptr);
+  CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES)
+      << "ValueError: The length of unique_id must be " << 
NCCL_UNIQUE_ID_BYTES << ", but got "
+      << unique_id_bytes.size() << ".";
+  // Set up local context of NCCL
   int device_id = device_ids[worker->worker_id];
   CUDA_CALL(cudaSetDevice(device_id));
+  CUDA_CALL(cudaStreamCreate(&ctx->stream));
   Device device{DLDeviceType::kDLCUDA, device_id};
+  DeviceAPI::Get(device)->SetStream(device, ctx->stream);
   worker->default_device = device;
   worker->ccl = "nccl";
   ctx->worker = worker;
   ctx->device_id = device_id;
-  ctx->comm = NCCLGlobalContext::Get()->communicators[worker->worker_id];
-  CUDA_CALL(cudaStreamCreate(&ctx->stream));
-  DeviceAPI::Get(device)->SetStream(device, ctx->stream);
+  // Initialize the communicator
+  ncclUniqueId id;
+  std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES);
+  NCCL_CALL(ncclCommInitRank(&ctx->comm, worker->num_workers, id, 
worker->worker_id));
 }
 
 void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
@@ -158,7 +134,7 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray 
recv) {
     }
   } else {
     if (send.defined()) {
-      LOG(WARNING) << "ValueError: buffer `send` must be None when worker_id 
!= 0. However, got "
+      LOG(WARNING) << "Buffer `send` must be None when worker_id != 0, but got 
"
                       "send = "
                    << send.get() << ". This will be ignored.";
     }
@@ -222,17 +198,12 @@ void RecvFromWorker0(NDArray buffer) {
 
 void SyncWorker() {
   NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  ICHECK(ctx->worker != nullptr);
   CUDA_CALL(cudaStreamSynchronize(ctx->stream));
 }
 
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl")
-    .set_body([](TVMArgs args, TVMRetValue* rv) -> void {
-      std::vector<int> device_ids;
-      for (int i = 0; i < args.num_args; ++i) {
-        device_ids.push_back(args[i].operator int());
-      }
-      InitCCL(device_ids);
-    });
+TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl").set_body_typed(InitCCL);
+TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl_per_worker").set_body_typed(InitCCLPerWorker);
 TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce")
     .set_body_typed([](NDArray send, int kind, NDArray recv) {
       CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << 
kind;
diff --git a/src/runtime/disco/nccl/utils.h b/src/runtime/disco/nccl/utils.h
index 4e5fb8cd74..7f40365136 100644
--- a/src/runtime/disco/nccl/utils.h
+++ b/src/runtime/disco/nccl/utils.h
@@ -69,6 +69,7 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) 
{
     return ncclBfloat16;
   }
   LOG(FATAL) << "ValueError: Unsupported data type " << dtype;
+  throw;
 }
 
 inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
@@ -85,6 +86,7 @@ inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
       return ncclAvg;
   }
   LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast<int>(kind);
+  throw;
 }
 
 }  // namespace nccl
diff --git a/src/runtime/disco/process_session.cc 
b/src/runtime/disco/process_session.cc
new file mode 100644
index 0000000000..8ddfdce812
--- /dev/null
+++ b/src/runtime/disco/process_session.cc
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <memory>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+#include "../../support/pipe.h"
+#include "../minrpc/rpc_reference.h"
+#include "./bcast_session.h"
+#include "./protocol.h"
+#include "./worker.h"
+#include "tvm/runtime/c_runtime_api.h"
+
+namespace tvm {
+namespace runtime {
+
+class DiscoPipeMessageQueue : private ::tvm::support::Pipe,
+                              private DiscoProtocol<DiscoPipeMessageQueue> {  // Sends/receives TVMArgs over a pipe handle, (de)serialized through RPCReference.
+ public:
+  explicit DiscoPipeMessageQueue(int64_t handle) : 
::tvm::support::Pipe(handle) {}  // `handle` is forwarded unchanged to support::Pipe.
+
+  ~DiscoPipeMessageQueue() = default;
+
+  void Send(const TVMArgs& args) {  // Serializes the whole argument pack into the pipe.
+    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, 
this);
+  }
+
+  TVMArgs Recv() {  // Reads one message from the pipe and returns it as TVMArgs.
+    {
+      this->RecycleAll();  // Drops objects and arena memory retained from the previous message.
+      uint64_t packet_nbytes = 0;
+      RPCCode code = RPCCode::kReturn;
+      this->Read(&packet_nbytes);  // Leading packet length: consumed but not otherwise used here.
+      this->Read(&code);  // RPC code: likewise consumed and discarded.
+    }
+    TVMValue* values = nullptr;
+    int* type_codes = nullptr;
+    int num_args = 0;
+    RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);  // NOTE(review): storage presumably comes from ArenaAlloc, so it would be valid only until the next Recv() - confirm.
+    return TVMArgs(values, type_codes, num_args);
+  }
+
+  using dmlc::Stream::Read;  // The using-declarations expose the stream primitives despite the private base.
+  using dmlc::Stream::ReadArray;
+  using dmlc::Stream::Write;
+  using dmlc::Stream::WriteArray;
+  friend struct RPCReference;  // NOTE(review): presumably needed so RPCReference can reach the protocol callbacks - confirm.
+  friend struct DiscoProtocol<DiscoPipeMessageQueue>;  // DiscoProtocol static_casts `this` to this class, which requires access to the private base.
+};
+
+class DiscoProcessChannel final : public DiscoChannel {  // Channel for one worker process, built from two one-directional pipe queues.
+ public:
+  DiscoProcessChannel(int64_t controler_to_worker_fd, int64_t 
worker_to_controler_fd)  // The session passes (write_fd, read_fd); WorkerProcess passes (read_fd, write_fd).
+      : controler_to_worker_(controler_to_worker_fd),
+        worker_to_controler_(worker_to_controler_fd) {}
+
+  DiscoProcessChannel(DiscoProcessChannel&& other) = delete;  // Neither movable nor copyable.
+  DiscoProcessChannel(const DiscoProcessChannel& other) = delete;
+
+  void Send(const TVMArgs& args) { controler_to_worker_.Send(args); }
+  TVMArgs Recv() { return controler_to_worker_.Recv(); }
+  void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); }
+  TVMArgs RecvReply() { return worker_to_controler_.Recv(); }
+
+  DiscoPipeMessageQueue controler_to_worker_;  // Carries Send()/Recv() traffic.
+  DiscoPipeMessageQueue worker_to_controler_;  // Carries Reply()/RecvReply() traffic.
+};
+
+class ProcessSessionObj final : public BcastSessionObj {  // Session with worker 0 on a local thread and workers 1..N-1 reached through per-process pipe channels.
+ public:
+  explicit ProcessSessionObj(int num_workers, PackedFunc process_pool)
+      : process_pool_(process_pool),
+        worker_0_(std::make_unique<DiscoWorkerThread>(0, num_workers, &worker_zero_data_)) {
+    std::vector<int64_t> read_fds;
+    std::vector<int64_t> write_fds;
+    read_fds.reserve(num_workers - 1);
+    write_fds.reserve(num_workers - 1);
+    for (int i = 1; i < num_workers; ++i) {
+      ShapeTuple fds = process_pool(i);  // process_pool(i) must return exactly (read_fd, write_fd) for worker i.
+      CHECK_EQ(fds.size(), 2) << "ValueError: process_pool(" << i << ") should return a tuple of "
+                              << "size 2, but got a tuple of size " << fds.size() << ".";
+      read_fds.push_back(fds[0]);
+      write_fds.push_back(fds[1]);
+    }
+    for (int i = 0; i < num_workers - 1; ++i) {  // workers_[i] is the channel of worker i + 1.
+      workers_.emplace_back(std::make_unique<DiscoProcessChannel>(write_fds[i], read_fds[i]));
+    }
+  }
+
+  void Kill() {  // Shuts the session down once; worker_0_ doubles as the "already killed" flag.
+    if (this->worker_0_ != nullptr) {
+      this->Shutdown();
+      this->worker_0_.reset();  // Joins the worker-0 thread (see ~DiscoWorkerThread).
+      this->workers_.clear();
+      this->process_pool_(0);  // NOTE(review): presumably asks the pool to tear down its processes - confirm against the pool implementation.
+    }
+  }
+
+  ~ProcessSessionObj() { Kill(); }
+
+  TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {  // Fetches register `reg_id` from worker `worker_id`.
+    if (worker_id == 0) {  // Worker 0 lives in this process: read its register file directly.
+      this->SyncWorker(worker_id);
+      return worker_0_->worker->register_file.at(reg_id);
+    }
+    {
+      TVMValue values[3];
+      int type_codes[3];
+      PackArgs(values, type_codes, static_cast<int>(DiscoAction::kDebugGetFromRemote), reg_id,
+               worker_id);
+      workers_.at(worker_id - 1)->Send(TVMArgs(values, type_codes, 3));  // .at(): an invalid worker_id throws instead of indexing out of bounds.
+    }
+    TVMArgs args = this->RecvReplyPacked(worker_id);
+    ICHECK_EQ(args.size(), 2);  // Reply layout: (action, value).
+    ICHECK(static_cast<DiscoAction>(args[0].operator int()) == DiscoAction::kDebugGetFromRemote);
+    TVMRetValue result;
+    result = args[1];
+    return result;
+  }
+
+  void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) {  // Writes `value` into register `reg_id` of worker `worker_id`.
+    if (worker_id == 0) {  // Worker 0 lives in this process: set its register directly.
+      this->SyncWorker(worker_id);
+      worker_0_->worker->SetRegister(reg_id, value);
+      return;
+    }
+    ObjectRef wrapped{nullptr};  // Declared outside the branch so the wrapper outlives the Send() below.
+    if (value.type_code() == kTVMNDArrayHandle || value.type_code() == kTVMObjectHandle) {
+      wrapped = DiscoDebugObject::Wrap(value);  // NDArrays/objects travel as a DiscoDebugObject.
+      TVMValue tvm_value;
+      int type_code = kTVMObjectHandle;
+      tvm_value.v_handle = const_cast<Object*>(wrapped.get());
+      value = TVMArgValue(tvm_value, type_code);
+    }
+    {
+      TVMValue values[4];
+      int type_codes[4];
+      PackArgs(values, type_codes, static_cast<int>(DiscoAction::kDebugSetRegister), reg_id,
+               worker_id, value);
+      workers_.at(worker_id - 1)->Send(TVMArgs(values, type_codes, 4));  // .at(): an invalid worker_id throws instead of indexing out of bounds.
+    }
+    // The reply is an acknowledgement that carries only the action code.
+    TVMArgs args = this->RecvReplyPacked(worker_id);
+    ICHECK_EQ(args.size(), 1);
+    ICHECK(static_cast<DiscoAction>(args[0].operator int()) == DiscoAction::kDebugSetRegister);
+  }
+
+  void BroadcastPacked(const TVMArgs& args) final {  // Sends the same message to worker 0 and to every worker process.
+    worker_0_->channel->Send(args);
+    for (std::unique_ptr<DiscoProcessChannel>& channel : workers_) {
+      channel->Send(args);
+    }
+  }
+
+  TVMArgs RecvReplyPacked(int worker_id) final {  // Receives one reply from the given worker.
+    if (worker_id == 0) {
+      return worker_0_->channel->RecvReply();
+    }
+    return this->workers_.at(worker_id - 1)->RecvReply();
+  }
+
+  PackedFunc process_pool_;  // Called with a worker id in the constructor and with 0 in Kill().
+  std::unique_ptr<DiscoWorkerThread> worker_0_;  // Worker 0, run on a thread of this process; null once killed.
+  std::vector<std::unique_ptr<DiscoProcessChannel>> workers_;  // Channels of workers 1..N-1, stored at index worker_id - 1.
+
+  static constexpr const char* _type_key = "runtime.disco.ProcessSession";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProcessSessionObj, SessionObj);
+};
+
+TVM_REGISTER_OBJECT_TYPE(DiscoDebugObject);  // The DiscoDebugObject type is registered in this translation unit.
+TVM_REGISTER_OBJECT_TYPE(ProcessSessionObj);
+
+Session Session::ProcessSession(int num_workers, String process_pool_creator) {  // Builds a ProcessSessionObj from a pool creator looked up by name in the global registry.
+  const PackedFunc* pf = Registry::Get(process_pool_creator);
+  CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator
+            << " in the registry. Please check if it is registered.";
+  PackedFunc process_pool = (*pf)(num_workers);  // The creator is called with num_workers and returns the pool as another PackedFunc.
+  auto n = make_object<ProcessSessionObj>(num_workers, process_pool);
+  return Session(n);
+}
+
+void WorkerProcess(int worker_id, int num_workers, int64_t read_fd, int64_t 
write_fd) {  // Runs a DiscoWorker over a pipe channel; registered below as "runtime.disco.WorkerProcess".
+  DiscoProcessChannel channel(read_fd, write_fd);  // read_fd backs Recv(); write_fd backs Reply().
+  DiscoWorker worker(worker_id, num_workers, nullptr, &channel);  // nullptr: no WorkerZeroData is supplied to this worker.
+  worker.MainLoop();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.SessionProcess").set_body_typed(Session::ProcessSession);  // Exposes Session::ProcessSession through the global registry.
+TVM_REGISTER_GLOBAL("runtime.disco.WorkerProcess").set_body_typed(WorkerProcess);  // Exposes WorkerProcess through the global registry.
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h
new file mode 100644
index 0000000000..50a6b091af
--- /dev/null
+++ b/src/runtime/disco/protocol.h
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_DISCO_PROTOCOL_H_
+#define TVM_RUNTIME_DISCO_PROTOCOL_H_
+
+#include <dmlc/io.h>
+#include <dmlc/memory_io.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/disco/session.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "../../support/arena.h"
+#include "../../support/base64.h"
+#include "../minrpc/rpc_reference.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief The communication protocol used by Disco message channel.
+ * \tparam SubClassType The subclass type that inherits this protocol. WriteObject/ReadObject
+ * static_cast `this` to it and call its Read/Write/ReadArray/WriteArray members.
+ */
+template <class SubClassType>
+struct DiscoProtocol {
+ protected:
+  /*! \brief Virtual destructor */
+  virtual ~DiscoProtocol() = default;
+
+  /*! \brief Release the objects kept alive by ReadObject and recycle the arena memory. */
+  inline void RecycleAll() {
+    this->object_arena_.clear();
+    this->arena_.RecycleAll();
+  }
+
+  /*! \brief Get the length of the object being serialized. Used by RPCReference. */
+  inline uint64_t GetObjectBytes(Object* obj);
+
+  /*! \brief Write the object to stream. Used by RPCReference. */
+  inline void WriteObject(Object* obj);
+
+  /*! \brief Read the object from stream. Used by RPCReference. */
+  inline void ReadObject(int* tcode, TVMValue* value);
+
+  /*! \brief Callback method used when starting a new message (no-op). Used by RPCReference. */
+  void MessageStart(uint64_t packet_nbytes) {}
+
+  /*! \brief Callback method used when a new message is complete (no-op). Used by RPCReference. */
+  void MessageDone() {}
+
+  /*! \brief Callback method when an error occurs in (de)-serialization. Used by RPCReference. */
+  void ThrowError(RPCServerStatus status) {
+    LOG(FATAL) << "InternalError: Unexpected error in RPC: " << 
RPCServerStatusToString(status);
+  }
+
+  /*! \brief Allocate `count` POD values of type T from the arena. Used by RPCReference. */
+  template <typename T>
+  T* ArenaAlloc(int count) {
+    static_assert(std::is_pod<T>::value, "need to be trival");
+    return arena_.template allocate_<T>(count);
+  }
+
+  support::Arena arena_;  // Backs ArenaAlloc; recycled by RecycleAll().
+  std::vector<ObjectRef> object_arena_;  // Holds references to objects created by ReadObject until RecycleAll().
+  friend struct RPCReference;
+};
+
+/*!
+ * \brief The debug extension of the communication protocol that allows serialization and
+ * deserialization of NDArrays and reflection-capable TVM objects.
+ */
+struct DiscoDebugObject : public Object {
+ public:
+  /*! \brief The data to be serialized */
+  TVMRetValue data;
+
+  /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */
+  static ObjectRef Wrap(const TVMRetValue& data) {
+    ObjectPtr<DiscoDebugObject> n = make_object<DiscoDebugObject>();
+    n->data = data;
+    return ObjectRef(n);
+  }
+
+  /*! \brief Overload of Wrap for argument values: the value is first copied into a TVMRetValue. */
+  static ObjectRef Wrap(const TVMArgValue& data) {
+    TVMRetValue rv;
+    rv = data;
+    return Wrap(std::move(rv));
+  }
+
+  /*! \brief Serialize the debug object to string */
+  inline std::string SaveToStr() const;
+  /*! \brief Deserialize the debug object from string */
+  static inline ObjectPtr<DiscoDebugObject> LoadFromStr(std::string json_str);
+  /*! \brief Size in bytes on the wire: a uint64 length prefix plus the serialized string. */
+  inline uint64_t GetObjectBytes() const { return sizeof(uint64_t) + this->SaveToStr().size(); }
+
+  static constexpr const char* _type_key = "runtime.disco.DiscoDebugObject";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DiscoDebugObject, Object);  // Parent must be the real base class; SessionObj made IsInstance<SessionObj>() true for this type.
+};
+
+template <class SubClassType>
+inline uint64_t DiscoProtocol<SubClassType>::GetObjectBytes(Object* obj) {  // Byte count of `obj` on the wire; each branch mirrors what WriteObject emits.
+  if (obj->IsInstance<DRefObj>()) {
+    return sizeof(uint32_t) + sizeof(int64_t);  // type index + reg_id
+  } else if (obj->IsInstance<StringObj>()) {
+    uint64_t size = static_cast<StringObj*>(obj)->size;
+    return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char);  // type index + length + characters
+  } else if (obj->IsInstance<ShapeTupleObj>()) {
+    uint64_t ndim = static_cast<ShapeTupleObj*>(obj)->size;
+    return sizeof(uint32_t) + sizeof(uint64_t) + ndim * 
sizeof(ShapeTupleObj::index_type);  // type index + ndim + elements
+  } else if (obj->IsInstance<DiscoDebugObject>()) {
+    return sizeof(uint32_t) + 
static_cast<DiscoDebugObject*>(obj)->GetObjectBytes();  // type index + length-prefixed serialized string
+  } else {
+    LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
+               << obj->GetTypeKey() << " (type_index = " << obj->type_index() 
<< ")";
+  }
+}
+template <class SubClassType>
+inline void DiscoProtocol<SubClassType>::WriteObject(Object* obj) {  // Writes a uint32 type tag followed by the type-specific payload.
+  SubClassType* self = static_cast<SubClassType*>(this);  // The stream primitives live on the subclass.
+  if (obj->IsInstance<DRefObj>()) {
+    int64_t reg_id = static_cast<DRefObj*>(obj)->reg_id;
+    self->template Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef);
+    self->template Write<int64_t>(reg_id);  // Only the register id is sent for a DRef.
+  } else if (obj->IsInstance<StringObj>()) {
+    StringObj* str = static_cast<StringObj*>(obj);
+    self->template Write<uint32_t>(TypeIndex::kRuntimeString);
+    self->template Write<uint64_t>(str->size);
+    self->template WriteArray<char>(str->data, str->size);
+  } else if (obj->IsInstance<ShapeTupleObj>()) {
+    ShapeTupleObj* shape = static_cast<ShapeTupleObj*>(obj);
+    self->template Write<uint32_t>(TypeIndex::kRuntimeShapeTuple);
+    self->template Write<uint64_t>(shape->size);
+    self->template WriteArray<ShapeTupleObj::index_type>(shape->data, 
shape->size);
+  } else if (obj->IsInstance<DiscoDebugObject>()) {
+    self->template Write<uint32_t>(TypeIndex::kRoot);  // kRoot is the wire tag for a DiscoDebugObject (matched in ReadObject).
+    std::string str = static_cast<DiscoDebugObject*>(obj)->SaveToStr();
+    self->template Write<uint64_t>(str.size());
+    self->template WriteArray<char>(str.data(), str.size());
+  } else {
+    LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
+               << obj->GetTypeKey() << " (type_index = " << obj->type_index() 
<< ")";
+  }
+}
+
+template <class SubClassType>
+inline void DiscoProtocol<SubClassType>::ReadObject(int* tcode, TVMValue* 
value) {  // Reads one object written by WriteObject and stores it into (*value, *tcode).
+  SubClassType* self = static_cast<SubClassType*>(this);  // The stream primitives live on the subclass.
+  ObjectRef result{nullptr};
+  uint32_t type_index;
+  self->template Read<uint32_t>(&type_index);  // The leading uint32 tag selects the payload layout.
+  if (type_index == TypeIndex::kRuntimeDiscoDRef) {
+    ObjectPtr<DRefObj> dref = make_object<DRefObj>();
+    self->template Read<int64_t>(&dref->reg_id);
+    dref->session = Session{nullptr};  // The session is not transmitted; the DRef arrives with a null session.
+    result = ObjectRef(std::move(dref));
+  } else if (type_index == TypeIndex::kRuntimeString) {
+    uint64_t size = 0;
+    self->template Read<uint64_t>(&size);
+    std::string data(size, '\0');
+    self->template ReadArray<char>(data.data(), size);
+    result = String(std::move(data));
+  } else if (type_index == TypeIndex::kRuntimeShapeTuple) {
+    uint64_t ndim = 0;
+    self->template Read<uint64_t>(&ndim);
+    std::vector<ShapeTupleObj::index_type> data(ndim);
+    self->template ReadArray<ShapeTupleObj::index_type>(data.data(), ndim);
+    result = ShapeTuple(std::move(data));
+  } else if (type_index == TypeIndex::kRoot) {  // kRoot tags a serialized DiscoDebugObject.
+    uint64_t size = 0;
+    self->template Read<uint64_t>(&size);
+    std::string data(size, '\0');
+    self->template ReadArray<char>(data.data(), size);
+    result = DiscoDebugObject::LoadFromStr(std::move(data))->data;  // Unwraps: the result is the carried value, not the wrapper.
+  } else {
+    LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
+               << Object::TypeIndex2Key(type_index) << " (type_index = " << 
type_index << ")";
+  }
+  TVMArgsSetter(value, tcode)(0, result);  // The setter fills both the value slot and its type code.
+  object_arena_.push_back(result);  // Keeps the object alive until the next RecycleAll().
+}
+
+inline std::string DiscoDebugObject::SaveToStr() const {  // Serializes `data`; the last character tags the encoding ('0' = JSON object, '1' = base64 NDArray).
+  if (this->data.type_code() == kTVMObjectHandle) {
+    ObjectRef obj = this->data;
+    const PackedFunc* f = runtime::Registry::Get("node.SaveJSON");
+    CHECK(f) << "ValueError: Cannot serialize object in non-debugging mode: " << obj->GetTypeKey();
+    std::string result = (*f)(obj);
+    result.push_back('0');  // Tag read back by LoadFromStr.
+    return result;
+  } else if (this->data.type_code() == kTVMNDArrayHandle) {
+    NDArray array = this->data;
+    std::string result;
+    {
+      dmlc::MemoryStringStream mstrm(&result);
+      support::Base64OutStream b64strm(&mstrm);
+      runtime::SaveDLTensor(&b64strm, array.operator->());
+      b64strm.Finish();
+    }
+    result.push_back('1');  // Tag read back by LoadFromStr.
+    return result;
+  }
+  LOG(FATAL) << "ValueError: Cannot serialize the following type code in non-debugging mode: "
+             << this->data.type_code() << "(" << ArgTypeCode2Str(this->data.type_code()) << ")";
+}
+
+inline ObjectPtr<DiscoDebugObject> DiscoDebugObject::LoadFromStr(std::string 
json_str) {  // Inverse of SaveToStr: rebuilds the wrapped value from its tagged string form.
+  ICHECK(!json_str.empty());
+  char control_bit = json_str.back();  // The last character selects the encoding: '0' = JSON object, '1' = base64 NDArray.
+  json_str.pop_back();
+  ObjectPtr<DiscoDebugObject> result = make_object<DiscoDebugObject>();
+  if (control_bit == '0') {
+    const PackedFunc* f = runtime::Registry::Get("node.LoadJSON");
+    CHECK(f) << "ValueError: Cannot deserialize object in non-debugging mode";
+    result->data = (*f)(json_str);
+  } else if (control_bit == '1') {
+    dmlc::MemoryStringStream mstrm(&json_str);
+    support::Base64InStream b64strm(&mstrm);
+    b64strm.InitPosition();
+    runtime::NDArray array;
+    ICHECK(array.Load(&b64strm));
+    result->data = std::move(array);
+  } else {
+    LOG(FATAL) << "ValueError: Unsupported control bit: " << control_bit
+               << ". Full string: " << json_str;  // json_str no longer contains the control character at this point.
+  }
+  return result;
+}
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_DISCO_PROTOCOL_H_
diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc
index e22b6c6d26..2cc027151a 100644
--- a/src/runtime/disco/session.cc
+++ b/src/runtime/disco/session.cc
@@ -31,15 +31,6 @@ struct SessionObj::FFI {
   }
 };
 
-void DRefObj::DebugCopyFrom(int worker_id, NDArray source) {
-  TVMRetValue target_array = this->DebugGetFromRemote(worker_id);
-  CHECK(target_array.type_code() == kTVMNDArrayHandle)
-      << "ValueError: The DRef on the remote is not an NDArray, instead, its 
type code is: "
-      << ArgTypeCode2Str(target_array.type_code());
-  NDArray target = target_array.operator NDArray();
-  target.CopyFrom(source);
-}
-
 TVM_REGISTER_OBJECT_TYPE(DRefObj);
 TVM_REGISTER_OBJECT_TYPE(SessionObj);
 
TVM_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession);
@@ -58,6 +49,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0")
     .set_body_method<Session>(&SessionObj::CopyToWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker")
     .set_body_method<Session>(&SessionObj::SyncWorker);
+TVM_REGISTER_GLOBAL("runtime.disco.SessionInitCCL")  //
+    .set_body_method<Session>(&SessionObj::InitCCL);
 TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked").set_body([](TVMArgs 
args, TVMRetValue* rv) {
   Session self = args[0];
   *rv = SessionObj::FFI::CallWithPacked(
diff --git a/src/runtime/disco/threaded_session.cc 
b/src/runtime/disco/threaded_session.cc
index cb84918d2d..349601fd03 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -27,16 +27,17 @@
 #include <thread>
 #include <utility>
 
-#include "../../support/arena.h"
 #include "../../support/ring_buffer.h"
 #include "../minrpc/rpc_reference.h"
 #include "./bcast_session.h"
+#include "./protocol.h"
 #include "./worker.h"
 
 namespace tvm {
 namespace runtime {
 
-class DiscoThreadedMessageQueue : public dmlc::Stream {
+class DiscoThreadedMessageQueue : private dmlc::Stream,
+                                  private 
DiscoProtocol<DiscoThreadedMessageQueue> {
  public:
   void Send(const TVMArgs& args) {
     RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, 
this);
@@ -67,10 +68,7 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
       condition_.wait(lock, [this] { return msg_cnt_.load() > 0; });
       --msg_cnt_;
     }
-    {
-      this->arena_.RecycleAll();
-      this->object_arena_.clear();
-    }
+    this->RecycleAll();
     uint64_t packet_nbytes = 0;
     RPCCode code = RPCCode::kReturn;
     this->Read(&packet_nbytes);
@@ -84,18 +82,6 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
     this->ring_buffer_.Reserve(n);
   }
 
-  void MessageDone() {}
-
-  void ThrowError(RPCServerStatus status) {
-    LOG(FATAL) << "InternalError: Unexpected error in RPC: " << 
RPCServerStatusToString(status);
-  }
-
-  template <typename T>
-  T* ArenaAlloc(int count) {
-    static_assert(std::is_pod<T>::value, "need to be trival");
-    return arena_.template allocate_<T>(count);
-  }
-
   size_t Read(void* data, size_t size) final {
     std::lock_guard<std::mutex> lock(mutex_);
     ring_buffer_.Read(data, size);
@@ -107,85 +93,17 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
     ring_buffer_.Write(data, size);
   }
 
-  uint64_t GetObjectBytes(Object* obj) {
-    if (obj->IsInstance<DRefObj>()) {
-      return sizeof(uint32_t) + sizeof(int64_t);
-    } else if (obj->IsInstance<StringObj>()) {
-      uint64_t size = static_cast<StringObj*>(obj)->size;
-      return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char);
-    } else if (obj->IsInstance<ShapeTupleObj>()) {
-      uint64_t ndim = static_cast<ShapeTupleObj*>(obj)->size;
-      return sizeof(uint32_t) + sizeof(uint64_t) + ndim * 
sizeof(ShapeTupleObj::index_type);
-    } else {
-      LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
-                 << obj->GetTypeKey() << " (type_index = " << 
obj->type_index() << ")";
-    }
-  }
-
-  void WriteObject(Object* obj) {
-    if (obj->IsInstance<DRefObj>()) {
-      int64_t reg_id = static_cast<DRefObj*>(obj)->reg_id;
-      this->Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef);
-      this->Write<int64_t>(reg_id);
-    } else if (obj->IsInstance<StringObj>()) {
-      StringObj* str = static_cast<StringObj*>(obj);
-      this->Write<uint32_t>(TypeIndex::kRuntimeString);
-      this->Write<uint64_t>(str->size);
-      this->WriteArray<char>(str->data, str->size);
-    } else if (obj->IsInstance<ShapeTupleObj>()) {
-      ShapeTupleObj* shape = static_cast<ShapeTupleObj*>(obj);
-      this->Write<uint32_t>(TypeIndex::kRuntimeShapeTuple);
-      this->Write<uint64_t>(shape->size);
-      this->WriteArray<ShapeTupleObj::index_type>(shape->data, shape->size);
-    } else {
-      LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
-                 << obj->GetTypeKey() << " (type_index = " << 
obj->type_index() << ")";
-    }
-  }
-
-  void ReadObject(int* tcode, TVMValue* value) {
-    ObjectRef result{nullptr};
-    uint32_t type_index;
-    this->Read<uint32_t>(&type_index);
-    if (type_index == TypeIndex::kRuntimeDiscoDRef) {
-      ObjectPtr<DRefObj> dref = make_object<DRefObj>();
-      this->Read<int64_t>(&dref->reg_id);
-      dref->session = Session{nullptr};
-      result = ObjectRef(std::move(dref));
-    } else if (type_index == TypeIndex::kRuntimeString) {
-      uint64_t size = 0;
-      this->Read<uint64_t>(&size);
-      std::string data(size, '\0');
-      this->ReadArray<char>(data.data(), size);
-      result = String(std::move(data));
-    } else if (type_index == TypeIndex::kRuntimeShapeTuple) {
-      uint64_t ndim = 0;
-      this->Read<uint64_t>(&ndim);
-      std::vector<ShapeTupleObj::index_type> data(ndim);
-      this->ReadArray<ShapeTupleObj::index_type>(data.data(), ndim);
-      result = ShapeTuple(std::move(data));
-    } else {
-      LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
-                 << Object::TypeIndex2Key(type_index) << " (type_index = " << 
type_index << ")";
-    }
-    *tcode = kTVMObjectHandle;
-    value->v_handle = const_cast<Object*>(result.get());
-    object_arena_.push_back(result);
-  }
-
   using dmlc::Stream::Read;
   using dmlc::Stream::ReadArray;
   using dmlc::Stream::Write;
   using dmlc::Stream::WriteArray;
   friend struct RPCReference;
+  friend struct DiscoProtocol<DiscoThreadedMessageQueue>;
 
   std::mutex mutex_;
   std::atomic<int> msg_cnt_{0};
   std::condition_variable condition_;
-
   support::RingBuffer ring_buffer_;
-  support::Arena arena_;
-  std::vector<ObjectRef> object_arena_;
 };
 
 class DiscoThreadChannel final : public DiscoChannel {
@@ -199,44 +117,52 @@ class DiscoThreadChannel final : public DiscoChannel {
   DiscoThreadedMessageQueue worker_to_controler_;
 };
 
+DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers,
+                                     WorkerZeroData* worker_zero_data_)
+    : channel(std::make_unique<DiscoThreadChannel>()),  // The worker below receives a raw pointer to this channel.
+      worker(
+          std::make_unique<DiscoWorker>(worker_id, num_workers, 
worker_zero_data_, channel.get())),
+      thread(std::make_unique<std::thread>([worker = this->worker.get()] { 
worker->MainLoop(); })) {  // The thread captures the heap-allocated worker, not `this`, and starts running MainLoop at once.
+}
+
 class ThreadedSessionObj final : public BcastSessionObj {
  public:
   explicit ThreadedSessionObj(int num_workers) {
     for (int i = 0; i < num_workers; ++i) {
-      std::unique_ptr<DiscoThreadChannel> channel = 
std::make_unique<DiscoThreadChannel>();
       WorkerZeroData* data = (i == 0) ? &worker_zero_data_ : nullptr;
-      workers_.emplace_back(std::make_unique<DiscoWorker>(i, num_workers, 
data, channel.get()));
-      channels_.emplace_back(std::move(channel));
-      worker_threads_.emplace_back([worker = workers_.back().get()] { 
worker->MainLoop(); });
+      workers_.emplace_back(i, num_workers, data);  // Constructs a DiscoWorkerThread in place; its constructor starts the worker thread.
     }
   }
 
   ~ThreadedSessionObj() {
     this->Shutdown();
-    for (std::thread& worker : this->worker_threads_) {
-      worker.join();
-    }
+    workers_.clear();  // Destroying each DiscoWorkerThread joins its thread.
   }
 
   TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {
     this->SyncWorker(worker_id);
-    return this->workers_.at(worker_id)->register_file.at(reg_id);
+    return this->workers_.at(worker_id).worker->register_file.at(reg_id);
+  }
+
+  void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) {  // Syncs the worker, then writes its register directly in this process.
+    this->SyncWorker(worker_id);
+    this->workers_.at(worker_id).worker->SetRegister(reg_id, value);
   }
 
   void BroadcastPacked(const TVMArgs& args) final {
-    for (const std::unique_ptr<DiscoThreadChannel>& channel : this->channels_) 
{
-      channel->Send(args);
+    for (const DiscoWorkerThread& worker : this->workers_) {
+      worker.channel->Send(args);
     }
   }
 
-  TVMArgs RecvReplyPacked(int worker_id) final { return 
channels_[worker_id]->RecvReply(); }
+  TVMArgs RecvReplyPacked(int worker_id) final {
+    return this->workers_.at(worker_id).channel->RecvReply();
+  }
 
   static constexpr const char* _type_key = "runtime.disco.ThreadedSession";
   TVM_DECLARE_FINAL_OBJECT_INFO(ThreadedSessionObj, SessionObj);
 
-  std::vector<std::unique_ptr<DiscoThreadChannel>> channels_;
-  std::vector<std::unique_ptr<DiscoWorker>> workers_;
-  std::vector<std::thread> worker_threads_;
+  std::vector<DiscoWorkerThread> workers_;  // One entry per worker, indexed by worker id.
 };
 
 TVM_REGISTER_OBJECT_TYPE(ThreadedSessionObj);
diff --git a/src/runtime/disco/worker.cc b/src/runtime/disco/worker.cc
index 63e814a7e2..3100985f18 100644
--- a/src/runtime/disco/worker.cc
+++ b/src/runtime/disco/worker.cc
@@ -19,11 +19,15 @@
 #include "./worker.h"
 
 #include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/disco/session.h>
+#include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
 #include <thread>
 
+#include "../../support/process_id.h"
 #include "./builtin.h"
+#include "./protocol.h"
 
 namespace tvm {
 namespace runtime {
@@ -43,11 +47,23 @@ DiscoWorker* DiscoWorker::ThreadLocal() {
   return ret;
 }
 
+void DiscoWorker::SetRegister(int reg_id, TVMArgValue value) {  // Writes `value` into register `reg_id`; reg_id must be within register_file.
+  ICHECK(0 <= reg_id && reg_id < static_cast<int>(register_file.size()));
+  TVMRetValue& rv = register_file.at(reg_id);
+  if (rv.type_code() == kTVMNDArrayHandle && value.type_code() == 
kTVMNDArrayHandle) {
+    NDArray dst = rv;
+    NDArray src = value;
+    dst.CopyFrom(src);  // NDArray into NDArray: copy into the existing array instead of rebinding the register.
+  } else {
+    rv = value;  // Any other combination replaces the register content.
+  }
+}
+
 struct DiscoWorker::Impl {
   static void MainLoop(DiscoWorker* self) {
     ThreadLocalDiscoWorker::Get()->worker = self;
-    LOG(INFO) << "[Thread " << std::this_thread::get_id() << "] Worker #" << 
self->worker_id
-              << " Launched";
+    LOG(INFO) << "[Worker #" << self->worker_id << "] " << 
support::GetProcessIdAndThreadIdHeader()
+              << " started";
     while (true) {
       TVMArgs args = self->channel->Recv();
       DiscoAction action = static_cast<DiscoAction>(args[0].operator int());
@@ -84,6 +100,17 @@ struct DiscoWorker::Impl {
           SyncWorker(self, reg_id);
           break;
         }
+        case DiscoAction::kDebugGetFromRemote: {
+          int worker_id = args[2];
+          DebugGetFromRemote(self, reg_id, worker_id);
+          break;
+        }
+        case DiscoAction::kDebugSetRegister: {
+          int worker_id = args[2];
+          TVMArgValue value = args[3];
+          DebugSetRegister(self, reg_id, worker_id, value);
+          break;
+        }
       }
     }
   }
@@ -131,6 +158,30 @@ struct DiscoWorker::Impl {
     }
   }
 
+  static void DebugGetFromRemote(DiscoWorker* self, int reg_id, int worker_id) 
{  // Replies with the content of register reg_id; workers whose id does not match do nothing.
+    if (worker_id == self->worker_id) {
+      TVMRetValue rv = GetReg(self, reg_id);
+      if (rv.type_code() == kTVMNDArrayHandle || rv.type_code() == 
kTVMObjectHandle) {
+        rv = DiscoDebugObject::Wrap(rv);  // Wrapped so the protocol can serialize it as a DiscoDebugObject.
+      }
+      TVMValue values[2];
+      int type_codes[2];
+      PackArgs(values, type_codes, 
static_cast<int>(DiscoAction::kDebugGetFromRemote), rv);
+      self->channel->Reply(TVMArgs(values, type_codes, 2));  // Reply layout: (action, value).
+    }
+  }
+
+  static void DebugSetRegister(DiscoWorker* self, int reg_id, int worker_id, 
TVMArgValue value) {  // Sets register reg_id and acknowledges; workers whose id does not match do nothing.
+    if (worker_id == self->worker_id) {
+      ::tvm::runtime::SyncWorker();  // Synchronize before the register is overwritten.
+      self->SetRegister(reg_id, value);
+      TVMValue values[1];
+      int type_codes[1];
+      PackArgs(values, type_codes, 
static_cast<int>(DiscoAction::kDebugSetRegister));
+      self->channel->Reply(TVMArgs(values, type_codes, 1));  // Acknowledgement carrying only the action code.
+    }
+  }
+
   static void CallPacked(DiscoWorker* self, int64_t ret_reg_id, PackedFunc 
func,
                          const TVMArgs& args) {
     TVMValue* values = const_cast<TVMValue*>(args.values);
diff --git a/src/runtime/disco/worker.h b/src/runtime/disco/worker.h
index f10382b068..e948fa1668 100644
--- a/src/runtime/disco/worker.h
+++ b/src/runtime/disco/worker.h
@@ -33,6 +33,7 @@
 #include <memory>
 #include <mutex>
 #include <queue>
+#include <thread>
 #include <utility>
 #include <vector>
 
@@ -81,6 +82,8 @@ class DiscoWorker {
   void MainLoop();
   /*! \brief Get the worker instance on the current thread */
   static DiscoWorker* ThreadLocal();
+  /*! \brief Set the specific register to a specific value */
+  void SetRegister(int reg_id, TVMArgValue value);
 
   /*! \brief The id of the worker.*/
   int worker_id;
@@ -108,6 +111,46 @@ class DiscoWorker {
   friend struct DiscoWorker::Impl;
 };
 
+/*!
+ * \brief A worker thread in Disco, which upon creation, launches a new thread 
to run the
+ * DiscoWorker.
+ * \sa DiscoWorker
+ */
+class DiscoWorkerThread {
+ public:
+  /*!
+   * \brief Construct a worker thread.
+   * \param worker_id The id of the worker.
+   * \param num_workers The total number of workers.
+   * \param worker_zero_data_ The data shared between worker-0 and the 
controller. It's a nullptr if
+   * the worker is not worker-0.
+   */
+  explicit DiscoWorkerThread(int worker_id, int num_workers, WorkerZeroData* 
worker_zero_data_);
+
+  /*! \brief Move constructor. */
+  explicit DiscoWorkerThread(DiscoWorkerThread&& other)
+      : channel(std::move(other.channel)),
+        worker(std::move(other.worker)),
+        thread(std::move(other.thread)) {}
+
+  /*! \brief Copy constructor is disabled */
+  DiscoWorkerThread(const DiscoWorkerThread& other) = delete;
+
+  /*! \brief Destructor that joins the thread before destruction */
+  ~DiscoWorkerThread() {
+    if (this->thread != nullptr) {
+      this->thread->join();
+    }
+  }
+
+  /*! \brief The communication channel between the controller and the worker */
+  std::unique_ptr<DiscoChannel> channel;
+  /*! \brief The worker whose internal state is visible to the controller */
+  std::unique_ptr<DiscoWorker> worker;
+  /*! \brief The thread that runs the worker's main loop. */
+  std::unique_ptr<std::thread> thread;
+};
+
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_DISCO_WORKER_H_
diff --git a/src/support/process_id.h b/src/support/process_id.h
new file mode 100644
index 0000000000..8462ae0dd2
--- /dev/null
+++ b/src/support/process_id.h
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file process_id.h
+ * \brief Platform independent helpers for querying the current process and thread ids.
+ */
+#ifndef TVM_SUPPORT_PROCESS_ID_H_
+#define TVM_SUPPORT_PROCESS_ID_H_
+
+#include <iomanip>
+#include <ios>
+#include <sstream>
+#include <string>
+#include <thread>
+
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <sys/types.h>
+#include <unistd.h>
+#endif
+
+namespace tvm {
+namespace support {
+
+/*! \brief Returns the PID of the current process as a 64-bit signed integer. 
*/
+inline int64_t GetProcessId() {
+  int64_t result;
+#ifdef _WIN32
+  DWORD pid = GetCurrentProcessId();
+  result = static_cast<int64_t>(pid);
+#else
+  pid_t pid = getpid();
+  result = static_cast<int64_t>(pid);
+#endif
+  return result;
+}
+
+/*! \brief Returns the PID and TID of the current process/thread as a 
formatted string */
+inline std::string GetProcessIdAndThreadIdHeader() {
+  std::ostringstream os;
+  os << "[PID " << GetProcessId() << " TID 0x" << std::setw(16) << 
std::setfill('0') << std::hex
+     << std::this_thread::get_id() << "]";
+  return os.str();
+}
+
+}  // namespace support
+}  // namespace tvm
+
+#endif  // TVM_SUPPORT_PROCESS_ID_H_
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index 6507af5699..f0f949ab80 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -19,30 +19,33 @@
 import tempfile
 
 import numpy as np
+import pytest
 
 import tvm
-import tvm.testing
 from tvm import dlight as dl
 from tvm import relax as rx
 from tvm.runtime import disco as di
 from tvm.runtime.relax_vm import VirtualMachine
 from tvm.script import relax as R
 
+_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
 
-def test_init():
-    devices = [0, 1]
 
-    sess = di.ThreadedSession(num_workers=len(devices))
[email protected]("session_kind", _all_session_kinds)
+def test_init(session_kind):
+    devices = [0, 1]
+    sess = session_kind(num_workers=len(devices))
     sess.init_ccl("nccl", *devices)
 
 
-def test_allreduce():
[email protected]("session_kind", _all_session_kinds)
+def test_allreduce(session_kind):
     devices = [0, 1]
+    sess = session_kind(num_workers=len(devices))
+    sess.init_ccl("nccl", *devices)
+
     array_1 = np.arange(12, dtype="float32").reshape(3, 4)
     array_2 = np.arange(start=1, stop=-11, step=-1, 
dtype="float32").reshape(3, 4)
-
-    sess = di.ThreadedSession(num_workers=len(devices))
-    sess.init_ccl("nccl", *devices)
     d_array = sess.empty((3, 4), "float32")
     d_array.debug_copy_from(0, array_1)
     d_array.debug_copy_from(1, array_2)
@@ -60,12 +63,13 @@ def test_allreduce():
         np.testing.assert_equal(result, expected)
 
 
-def test_broadcast_from_worker0():
[email protected]("session_kind", _all_session_kinds)
+def test_broadcast_from_worker0(session_kind):
     devices = [0, 1]
-    array = np.arange(12, dtype="float32").reshape(3, 4)
-
-    sess = di.ThreadedSession(num_workers=len(devices))
+    sess = session_kind(num_workers=len(devices))
     sess.init_ccl("nccl", *devices)
+
+    array = np.arange(12, dtype="float32").reshape(3, 4)
     d_array = sess.empty((3, 4), "float32")
     d_array.debug_copy_from(0, array)
     dst_array = sess.empty((3, 4), "float32")
@@ -74,12 +78,13 @@ def test_broadcast_from_worker0():
     np.testing.assert_equal(result, array)
 
 
-def test_scatter():
[email protected]("session_kind", _all_session_kinds)
+def test_scatter(session_kind):
     devices = [0, 1]
-    array = np.arange(36, dtype="float32").reshape(3, 4, 3)
-
-    sess = di.ThreadedSession(num_workers=len(devices))
+    sess = session_kind(num_workers=len(devices))
     sess.init_ccl("nccl", *devices)
+
+    array = np.arange(36, dtype="float32").reshape(3, 4, 3)
     d_src = sess.empty((3, 4, 3), "float32")
     d_dst = sess.empty((3, 3, 2), "float32")
 
@@ -96,29 +101,29 @@ def test_scatter():
     )
 
 
-# def test_gather():
-#     num_workers = 2
-#     devices = [1, 2]
-#     array = np.arange(36, dtype="float32")
-
-#     sess = di.ThreadedSession(num_workers=num_workers)
-#     sess.init_ccl("nccl", *devices)
-#     d_src = sess.empty((3, 3, 2), "float32")
-#     d_dst = sess.empty((3, 4, 3), "float32")
-
-#     d_src.debug_copy_from(0, array[:18])
-#     d_src.debug_copy_from(1, array[18:])
-
-#     sess.gather_to_worker0(d_src, d_dst)
[email protected]("session_kind", _all_session_kinds)
+def test_gather(session_kind):
+    devices = [1, 2]
+    sess = session_kind(num_workers=len(devices))
+    sess.init_ccl("nccl", *devices)
 
-#     np.testing.assert_equal(
-#         d_dst.debug_get_from_remote(0).numpy(),
-#         array.reshape(3, 4, 3),
-#     )
+    array = np.arange(36, dtype="float32")
+    d_src = sess.empty((3, 3, 2), "float32")
+    d_dst = sess.empty((3, 4, 3), "float32")
+    d_src.debug_copy_from(0, array[:18])
+    d_src.debug_copy_from(1, array[18:])
+    sess.gather_to_worker0(d_src, d_dst)
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(0).numpy(),
+        array.reshape(3, 4, 3),
+    )
 
 
-def test_mlp():  # pylint: disable=too-many-locals
[email protected]("session_kind", _all_session_kinds)
+def test_mlp(session_kind):  # pylint: disable=too-many-locals
     devices = [0, 1]
+    sess = session_kind(num_workers=len(devices))
+    sess.init_ccl("nccl", *devices)
 
     # pylint: disable=invalid-name
     @tvm.script.ir_module
@@ -193,8 +198,6 @@ def test_mlp():  # pylint: disable=too-many-locals
         path = tmpdir + "/test.so"
         relax_build(ShardedMLP, target).export_library(path)
 
-        sess = di.ThreadedSession(num_workers=len(devices))
-        sess.init_ccl("nccl", *devices)
         mod = sess.load_vm_module(path)
 
         d_X = sess.empty((128, 128), "float32")
@@ -215,8 +218,11 @@ def test_mlp():  # pylint: disable=too-many-locals
     np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4)
 
 
-def test_attention():  # pylint: disable=too-many-locals,too-many-statements
[email protected]("session_kind", _all_session_kinds)
+def test_attention(session_kind):  # pylint: 
disable=too-many-locals,too-many-statements
     devices = [0, 1]
+    sess = session_kind(num_workers=len(devices))
+    sess.init_ccl("nccl", *devices)
 
     # pylint: disable=invalid-name
     @tvm.script.ir_module
@@ -343,8 +349,6 @@ def test_attention():  # pylint: 
disable=too-many-locals,too-many-statements
         path = tmpdir + "/test.so"
         relax_build(ShardedAttention, target).export_library(path)
 
-        sess = di.ThreadedSession(num_workers=len(devices))
-        sess.init_ccl("nccl", *devices)
         mod = sess.load_vm_module(path)
 
         d_X = sess.empty((1, 10, 128), "float32")
@@ -372,4 +376,10 @@ def test_attention():  # pylint: 
disable=too-many-locals,too-many-statements
 
 
 if __name__ == "__main__":
-    tvm.testing.main()
+    test_init(di.ProcessSession)
+    test_allreduce(di.ProcessSession)
+    test_broadcast_from_worker0(di.ProcessSession)
+    test_scatter(di.ProcessSession)
+    test_gather(di.ProcessSession)
+    test_mlp(di.ProcessSession)
+    test_attention(di.ProcessSession)
diff --git a/tests/python/disco/test_session.py 
b/tests/python/disco/test_session.py
index a2c0906f22..40dcb04911 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -19,15 +19,16 @@
 import tempfile
 
 import numpy as np
+import pytest
 
 import tvm
 from tvm import relax as rx
-from tvm._ffi import register_func
 from tvm.runtime import ShapeTuple, String
 from tvm.runtime import disco as di
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
+from tvm.testing import disco as _
 
 
 def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
@@ -44,29 +45,23 @@ def _numpy_from_worker_0(sess: di.Session, remote_array, 
shape, dtype):
     return host_array.numpy()
 
 
-def test_int():
-    num_workers = 4
+_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
 
-    @register_func("tests.disco.add_one", override=True)
-    def add_one(x: int) -> int:  # pylint: disable=invalid-name
-        return x + 1
 
-    sess = di.ThreadedSession(num_workers=num_workers)
[email protected]("session_kind", _all_session_kinds)
+def test_int(session_kind):  # pylint: disable=invalid-name
+    num_workers = 4
+    sess = session_kind(num_workers=num_workers)
     func: di.DPackedFunc = sess.get_global_func("tests.disco.add_one")
     result: di.DRef = func(1)
-
     for i in range(num_workers):
         assert result.debug_get_from_remote(i) == 2
 
 
-def test_float():
[email protected]("session_kind", _all_session_kinds)
+def test_float(session_kind):
     num_workers = 4
-
-    @register_func("tests.disco.add_one_float", override=True)
-    def add_one(x: float):  # pylint: disable=invalid-name
-        return x + 0.5
-
-    sess = di.ThreadedSession(num_workers=num_workers)
+    sess = session_kind(num_workers=num_workers)
     func: di.DPackedFunc = sess.get_global_func("tests.disco.add_one_float")
     result: di.DRef = func(1.5)
 
@@ -74,32 +69,23 @@ def test_float():
         assert result.debug_get_from_remote(i) == 2.0
 
 
-def test_ndarray():
[email protected]("session_kind", _all_session_kinds)
+def test_ndarray(session_kind):
     num_workers = 4
-
-    @register_func("tests.disco.add_one_ndarray", override=True)
-    def add_one(x: tvm.runtime.NDArray) -> tvm.runtime.NDArray:  # pylint: 
disable=invalid-name
-        return tvm.nd.array(x.numpy() + 1)
-
+    sess = session_kind(num_workers=num_workers)
     device = tvm.cpu(0)
     x_np = np.arange(6).astype("float32").reshape([2, 3])
     y_np = np.arange(6).astype("float32").reshape([2, 3]) + 1
-
-    sess = di.ThreadedSession(num_workers=num_workers)
     x_disc = _numpy_to_worker_0(sess, x_np, device=device)
     y_disc = sess.get_global_func("tests.disco.add_one_ndarray")(x_disc)
     y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, 
dtype=y_np.dtype)
     np.testing.assert_equal(y_nd, y_np)
 
 
-def test_string():
[email protected]("session_kind", _all_session_kinds)
+def test_string(session_kind):
     num_workers = 4
-
-    @register_func("tests.disco.str", override=True)
-    def my_str_func(x: str):  # pylint: disable=invalid-name
-        return x + "_suffix"
-
-    sess = di.ThreadedSession(num_workers=num_workers)
+    sess = session_kind(num_workers=num_workers)
     func: di.DPackedFunc = sess.get_global_func("tests.disco.str")
     result: di.DRef = func("hello")
 
@@ -107,15 +93,10 @@ def test_string():
         assert result.debug_get_from_remote(i) == "hello_suffix"
 
 
-def test_string_obj():
[email protected]("session_kind", _all_session_kinds)
+def test_string_obj(session_kind):
     num_workers = 4
-
-    @register_func("tests.disco.str_obj", override=True)
-    def my_str_func(x: String):  # pylint: disable=invalid-name
-        assert isinstance(x, String)
-        return String(x + "_suffix")
-
-    sess = di.ThreadedSession(num_workers=num_workers)
+    sess = session_kind(num_workers=num_workers)
     func: di.DPackedFunc = sess.get_global_func("tests.disco.str_obj")
     result: di.DRef = func(String("hello"))
 
@@ -125,26 +106,22 @@ def test_string_obj():
         assert value == "hello_suffix"
 
 
-def test_shape_tuple():
[email protected]("session_kind", _all_session_kinds)
+def test_shape_tuple(session_kind):
     num_workers = 4
-
-    @register_func("tests.disco.shape_tuple", override=True)
-    def my_str_func(x: ShapeTuple):  # pylint: disable=invalid-name
-        assert isinstance(x, ShapeTuple)
-        return ShapeTuple(list(x) + [4, 5])
-
-    sess = di.ThreadedSession(num_workers=num_workers)
+    sess = session_kind(num_workers=num_workers)
     func: di.DPackedFunc = sess.get_global_func("tests.disco.shape_tuple")
     result: di.DRef = func(ShapeTuple([1, 2, 3]))
-
     for i in range(num_workers):
         value = result.debug_get_from_remote(i)
         assert isinstance(value, ShapeTuple)
         assert list(value) == [1, 2, 3, 4, 5]
 
 
-def test_vm_module():
[email protected]("session_kind", _all_session_kinds)
+def test_vm_module(session_kind):
     num_workers = 4
+    sess = session_kind(num_workers=num_workers)
 
     # pylint: disable=invalid-name
     @I.ir_module
@@ -172,7 +149,6 @@ def test_vm_module():
         y_np = x_np.transpose()
 
         rx.build(TestMod, target="llvm").export_library(path)
-        sess = di.ThreadedSession(num_workers=num_workers)
         mod = sess.load_vm_module(path, device=device)
 
         x_disc = _numpy_to_worker_0(sess, x_np, device=device)
@@ -181,8 +157,10 @@ def test_vm_module():
         np.testing.assert_equal(y_nd, y_np)
 
 
-def test_vm_multi_func():
[email protected]("session_kind", _all_session_kinds)
+def test_vm_multi_func(session_kind):
     num_workers = 4
+    sess = session_kind(num_workers=num_workers)
 
     # pylint: disable=invalid-name
     @I.ir_module
@@ -231,7 +209,6 @@ def test_vm_multi_func():
         y_np = x_np.transpose()
 
         rx.build(TestMod, target="llvm").export_library(path)
-        sess = di.ThreadedSession(num_workers=num_workers)
         mod = sess.load_vm_module(path, device=device)
 
         x_disc = _numpy_to_worker_0(sess, x_np, device=device)
@@ -244,11 +221,11 @@ def test_vm_multi_func():
 
 
 if __name__ == "__main__":
-    test_int()
-    test_float()
-    test_string()
-    test_string_obj()
-    test_shape_tuple()
-    test_ndarray()
-    test_vm_module()
-    test_vm_multi_func()
+    test_int(di.ProcessSession)
+    test_float(di.ProcessSession)
+    test_string(di.ProcessSession)
+    test_string_obj(di.ProcessSession)
+    test_shape_tuple(di.ProcessSession)
+    test_ndarray(di.ProcessSession)
+    test_vm_module(di.ProcessSession)
+    test_vm_multi_func(di.ProcessSession)

Reply via email to