This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b6c00d756 [Disco] Implement SocketSession (#17182)
1b6c00d756 is described below

commit 1b6c00d7560afded9b5380abfd3f182461b9448d
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Jul 25 21:11:33 2024 -0700

    [Disco] Implement SocketSession (#17182)
    
    * [Disco] Implement SocketSession
    
    Implements SocketSession that connects multiple local worker
    processes/threads over multiple distributed nodes via TCP socket.
    
    * doc
    
    * lint
    
    * resolve conflict
    
    * lint
    
    * add local worker id
    
    * lint
    
    * lint
    
    * disable for hexagon
    
    * remove from header
---
 CMakeLists.txt                                     |   6 +
 include/tvm/runtime/disco/disco_worker.h           |   4 +
 include/tvm/runtime/disco/session.h                |   1 +
 .../disco_remote_socket_session.py}                |  26 +-
 python/tvm/runtime/disco/__init__.py               |   1 +
 python/tvm/runtime/disco/session.py                |  23 ++
 src/runtime/disco/bcast_session.h                  |  20 ++
 src/runtime/disco/disco_worker.cc                  |   4 +-
 src/runtime/disco/distributed/socket_session.cc    | 332 +++++++++++++++++++++
 src/runtime/disco/message_queue.h                  | 133 +++++++++
 src/runtime/disco/nccl/nccl.cc                     |   4 +-
 src/runtime/disco/process_session.cc               | 128 ++------
 src/runtime/disco/threaded_session.cc              |   4 +
 src/support/socket.h                               |   6 +-
 tests/python/disco/test_session.py                 |  87 +++++-
 15 files changed, 660 insertions(+), 119 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7575d6c2b4..7fba5355f0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -387,6 +387,12 @@ if(BUILD_FOR_HEXAGON)
   add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0)
 endif()
 
+# distributed disco runtime is disabled for hexagon
+if (NOT BUILD_FOR_HEXAGON)
+  tvm_file_glob(GLOB RUNTIME_DISCO_DISTRIBUTED_SRCS 
src/runtime/disco/distributed/*.cc)
+  list(APPEND RUNTIME_SRCS ${RUNTIME_DISCO_DISTRIBUTED_SRCS})
+endif()
+
 # Package runtime rules
 if(NOT USE_RTTI)
   add_definitions(-DDMLC_ENABLE_RTTI=0)
diff --git a/include/tvm/runtime/disco/disco_worker.h 
b/include/tvm/runtime/disco/disco_worker.h
index 13f94802c8..c9c85b7dbf 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -52,6 +52,7 @@ class DiscoWorker {
   explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
                        WorkerZeroData* worker_zero_data, DiscoChannel* channel)
       : worker_id(worker_id),
+        local_worker_id(worker_id),
         num_workers(num_workers),
         num_groups(num_groups),
         default_device(Device{DLDeviceType::kDLCPU, 0}),
@@ -68,6 +69,9 @@ class DiscoWorker {
 
   /*! \brief The id of the worker.*/
   int worker_id;
+  /*! \brief The local id of the worker. This can be different from worker_id 
if the session
+   * consists of multiple sub-sessions. */
+  int local_worker_id;
   /*! \brief Total number of workers */
   int num_workers;
   /*! \brief Total number of workers */
diff --git a/include/tvm/runtime/disco/session.h 
b/include/tvm/runtime/disco/session.h
index 97fa79096d..9c34f8a2af 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -281,6 +281,7 @@ class Session : public ObjectRef {
    */
   TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
                                         String process_pool_creator, String 
entrypoint);
+
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, 
SessionObj);
 };
 
diff --git a/python/tvm/runtime/disco/__init__.py 
b/python/tvm/exec/disco_remote_socket_session.py
similarity index 58%
copy from python/tvm/runtime/disco/__init__.py
copy to python/tvm/exec/disco_remote_socket_session.py
index 856e69bc35..3111ce30ac 100644
--- a/python/tvm/runtime/disco/__init__.py
+++ b/python/tvm/exec/disco_remote_socket_session.py
@@ -14,12 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM distributed runtime API."""
-from .session import (
-    DModule,
-    DPackedFunc,
-    DRef,
-    ProcessSession,
-    Session,
-    ThreadedSession,
-)
+# pylint: disable=invalid-name
+"""Launch disco session in the remote node and connect to the server."""
+import sys
+import tvm
+from . import disco_worker as _  # pylint: disable=unused-import
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 4:
+        print("Usage: <server_host> <server_port> <num_workers>")
+        sys.exit(1)
+
+    server_host = sys.argv[1]
+    server_port = int(sys.argv[2])
+    num_workers = int(sys.argv[3])
+    func = tvm.get_global_func("runtime.disco.RemoteSocketSession")
+    func(server_host, server_port, num_workers)
diff --git a/python/tvm/runtime/disco/__init__.py 
b/python/tvm/runtime/disco/__init__.py
index 856e69bc35..2ba524cade 100644
--- a/python/tvm/runtime/disco/__init__.py
+++ b/python/tvm/runtime/disco/__init__.py
@@ -22,4 +22,5 @@ from .session import (
     ProcessSession,
     Session,
     ThreadedSession,
+    SocketSession,
 )
diff --git a/python/tvm/runtime/disco/session.py 
b/python/tvm/runtime/disco/session.py
index 89ef549df3..1749942a9c 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -574,6 +574,29 @@ class ProcessSession(Session):
         func(config, os.getpid())
 
 
+@register_func("runtime.disco.create_socket_session_local_workers")
+def _create_socket_session_local_workers(num_workers) -> Session:
+    """Create the local session for each distributed node over socket 
session."""
+    return ProcessSession(num_workers)
+
+
+@register_object("runtime.disco.SocketSession")
+class SocketSession(Session):
+    """A Disco session backed by socket-based multi-node communication."""
+
+    def __init__(
+        self, num_nodes: int, num_workers_per_node: int, num_groups: int, 
host: str, port: int
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.SocketSession,  # type: ignore # pylint: disable=no-member
+            num_nodes,
+            num_workers_per_node,
+            num_groups,
+            host,
+            port,
+        )
+
+
 @register_func("runtime.disco._configure_structlog")
 def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None:
     """Configure structlog for all disco workers
diff --git a/src/runtime/disco/bcast_session.h 
b/src/runtime/disco/bcast_session.h
index 1a4df634b7..0e4ca614d4 100644
--- a/src/runtime/disco/bcast_session.h
+++ b/src/runtime/disco/bcast_session.h
@@ -65,6 +65,16 @@ class BcastSessionObj : public SessionObj {
    * \param TVMArgs The input arguments in TVM's PackedFunc calling convention
    */
   virtual void BroadcastPacked(const TVMArgs& args) = 0;
+
+  /*!
+   * \brief Send a packed sequence to a worker. This function is usually 
called by the controler to
+   * communicate with worker-0, because the worker-0 is assumed to be always 
collocated with the
+   * controler. Sending to other workers may not be supported.
+   * \param worker_id The worker id to send the packed sequence to.
+   * \param args The packed sequence to send.
+   */
+  virtual void SendPacked(int worker_id, const TVMArgs& args) = 0;
+
   /*!
    * \brief Receive a packed sequence from a worker. This function is usually 
called by the
    * controler to communicate with worker-0, because the worker-0 is assumed 
to be always
@@ -83,6 +93,16 @@ class BcastSessionObj : public SessionObj {
 
   struct Internal;
   friend struct Internal;
+  friend class SocketSessionObj;
+  friend class RemoteSocketSession;
+};
+
+/*!
+ * \brief Managed reference to BcastSessionObj.
+ */
+class BcastSession : public Session {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BcastSession, Session, 
BcastSessionObj);
 };
 
 }  // namespace runtime
diff --git a/src/runtime/disco/disco_worker.cc 
b/src/runtime/disco/disco_worker.cc
index 5e6f401054..4007b104f2 100644
--- a/src/runtime/disco/disco_worker.cc
+++ b/src/runtime/disco/disco_worker.cc
@@ -120,7 +120,7 @@ struct DiscoWorker::Impl {
   }
 
   static void CopyFromWorker0(DiscoWorker* self, int reg_id) {
-    if (self->worker_zero_data != nullptr) {
+    if (self->worker_id == 0) {
       NDArray tgt = GetNDArrayFromHost(self);
       NDArray src = GetReg(self, reg_id);
       tgt.CopyFrom(src);
@@ -128,7 +128,7 @@ struct DiscoWorker::Impl {
   }
 
   static void CopyToWorker0(DiscoWorker* self, int reg_id) {
-    if (self->worker_zero_data != nullptr) {
+    if (self->worker_id == 0) {
       NDArray src = GetNDArrayFromHost(self);
       NDArray tgt = GetReg(self, reg_id);
       tgt.CopyFrom(src);
diff --git a/src/runtime/disco/distributed/socket_session.cc 
b/src/runtime/disco/distributed/socket_session.cc
new file mode 100644
index 0000000000..07196be305
--- /dev/null
+++ b/src/runtime/disco/distributed/socket_session.cc
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/registry.h>
+
+#include <numeric>
+
+#include "../../../support/socket.h"
+#include "../bcast_session.h"
+#include "../message_queue.h"
+
+namespace tvm {
+namespace runtime {
+
+using namespace tvm::support;
+
+enum class DiscoSocketAction {
+  kShutdown = static_cast<int>(DiscoAction::kShutDown),
+  kSend,
+  kReceive,
+};
+
+class DiscoSocketChannel : public DiscoChannel {
+ public:
+  explicit DiscoSocketChannel(const TCPSocket& socket)
+      : socket_(socket), message_queue_(&socket_) {}
+
+  DiscoSocketChannel(DiscoSocketChannel&& other) = delete;
+  DiscoSocketChannel(const DiscoSocketChannel& other) = delete;
+  void Send(const TVMArgs& args) { message_queue_.Send(args); }
+  TVMArgs Recv() { return message_queue_.Recv(); }
+  void Reply(const TVMArgs& args) { message_queue_.Send(args); }
+  TVMArgs RecvReply() { return message_queue_.Recv(); }
+
+ private:
+  TCPSocket socket_;
+  DiscoStreamMessageQueue message_queue_;
+};
+
+class SocketSessionObj : public BcastSessionObj {
+ public:
+  explicit SocketSessionObj(int num_nodes, int num_workers_per_node, int 
num_groups,
+                            const String& host, int port)
+      : num_nodes_(num_nodes), num_workers_per_node_(num_workers_per_node) {
+    const PackedFunc* f_create_local_session =
+        Registry::Get("runtime.disco.create_socket_session_local_workers");
+    ICHECK(f_create_local_session != nullptr)
+        << "Cannot find function 
runtime.disco.create_socket_session_local_workers";
+    local_session_ = 
((*f_create_local_session)(num_workers_per_node)).AsObjectRef<BcastSession>();
+    DRef f_init_workers =
+        
local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers");
+    local_session_->CallPacked(f_init_workers, num_nodes_, /*node_id=*/0, 
num_groups,
+                               num_workers_per_node_);
+
+    Socket::Startup();
+    socket_.Create();
+    socket_.SetKeepAlive(true);
+    socket_.Bind(SockAddr(host.c_str(), port));
+    socket_.Listen();
+    LOG(INFO) << "SocketSession controller listening on " << host << ":" << 
port;
+
+    TVMValue values[4];
+    int type_codes[4];
+    TVMArgsSetter setter(values, type_codes);
+    setter(0, num_nodes);
+    setter(1, num_workers_per_node);
+    setter(2, num_groups);
+
+    for (int i = 0; i + 1 < num_nodes; ++i) {
+      SockAddr addr;
+      remote_sockets_.push_back(socket_.Accept(&addr));
+      
remote_channels_.emplace_back(std::make_unique<DiscoSocketChannel>(remote_sockets_.back()));
+      setter(3, i + 1);
+      // Send metadata to each remote node:
+      //  - num_nodes
+      //  - num_workers_per_node
+      //  - num_groups
+      //  - node_id
+      remote_channels_.back()->Send(TVMArgs(values, type_codes, 4));
+      LOG(INFO) << "Remote node " << addr.AsString() << " connected";
+    }
+  }
+
+  int64_t GetNumWorkers() final { return num_nodes_ * num_workers_per_node_; }
+
+  TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      return local_session_->DebugGetFromRemote(reg_id, worker_id);
+    } else {
+      std::vector<TVMValue> values(5);
+      std::vector<int> type_codes(5);
+      PackArgs(values.data(), type_codes.data(), 
static_cast<int>(DiscoSocketAction::kSend),
+               worker_id, static_cast<int>(DiscoAction::kDebugGetFromRemote), 
reg_id, worker_id);
+
+      remote_channels_[node_id - 1]->Send(TVMArgs(values.data(), 
type_codes.data(), values.size()));
+      TVMArgs args = this->RecvReplyPacked(worker_id);
+      ICHECK_EQ(args.size(), 2);
+      ICHECK(static_cast<DiscoAction>(args[0].operator int()) == 
DiscoAction::kDebugGetFromRemote);
+      TVMRetValue result;
+      result = args[1];
+      return result;
+    }
+  }
+
+  void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) 
final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      local_session_->DebugSetRegister(reg_id, value, worker_id);
+    } else {
+      ObjectRef wrapped{nullptr};
+      if (value.type_code() == kTVMNDArrayHandle || value.type_code() == 
kTVMObjectHandle) {
+        wrapped = DiscoDebugObject::Wrap(value);
+        TVMValue tvm_value;
+        int type_code = kTVMObjectHandle;
+        tvm_value.v_handle = const_cast<Object*>(wrapped.get());
+        value = TVMArgValue(tvm_value, type_code);
+      }
+      {
+        TVMValue values[6];
+        int type_codes[6];
+        PackArgs(values, type_codes, 
static_cast<int>(DiscoSocketAction::kSend), worker_id,
+                 static_cast<int>(DiscoAction::kDebugSetRegister), reg_id, 
worker_id, value);
+        remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 6));
+      }
+      TVMRetValue result;
+      TVMArgs args = this->RecvReplyPacked(worker_id);
+      ICHECK_EQ(args.size(), 1);
+      ICHECK(static_cast<DiscoAction>(args[0].operator int()) == 
DiscoAction::kDebugSetRegister);
+    }
+  }
+
+  void BroadcastPacked(const TVMArgs& args) final {
+    local_session_->BroadcastPacked(args);
+    std::vector<TVMValue> values(args.size() + 2);
+    std::vector<int> type_codes(args.size() + 2);
+    PackArgs(values.data(), type_codes.data(), 
static_cast<int>(DiscoSocketAction::kSend), -1);
+    std::copy(args.values, args.values + args.size(), values.begin() + 2);
+    std::copy(args.type_codes, args.type_codes + args.size(), 
type_codes.begin() + 2);
+    for (auto& channel : remote_channels_) {
+      channel->Send(TVMArgs(values.data(), type_codes.data(), values.size()));
+    }
+  }
+
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      local_session_->SendPacked(worker_id, args);
+      return;
+    }
+    std::vector<TVMValue> values(args.size() + 2);
+    std::vector<int> type_codes(args.size() + 2);
+    PackArgs(values.data(), type_codes.data(), 
static_cast<int>(DiscoSocketAction::kSend),
+             worker_id);
+    std::copy(args.values, args.values + args.size(), values.begin() + 2);
+    std::copy(args.type_codes, args.type_codes + args.size(), 
type_codes.begin() + 2);
+    remote_channels_[node_id - 1]->Send(TVMArgs(values.data(), 
type_codes.data(), values.size()));
+  }
+
+  TVMArgs RecvReplyPacked(int worker_id) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      return local_session_->RecvReplyPacked(worker_id);
+    }
+    TVMValue values[2];
+    int type_codes[2];
+    PackArgs(values, type_codes, 
static_cast<int>(DiscoSocketAction::kReceive), worker_id);
+    remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 2));
+    return remote_channels_[node_id - 1]->Recv();
+  }
+
+  void AppendHostNDArray(const NDArray& host_array) final {
+    local_session_->AppendHostNDArray(host_array);
+  }
+
+  void Shutdown() final {
+    // local session will be implicitly shutdown by its destructor
+    TVMValue values[2];
+    int type_codes[2];
+    PackArgs(values, type_codes, 
static_cast<int>(DiscoSocketAction::kShutdown), -1);
+    for (auto& channel : remote_channels_) {
+      channel->Send(TVMArgs(values, type_codes, 2));
+    }
+    for (auto& socket : remote_sockets_) {
+      socket.Close();
+    }
+    remote_sockets_.clear();
+    remote_channels_.clear();
+    if (!socket_.IsClosed()) {
+      socket_.Close();
+    }
+    Socket::Finalize();
+  }
+
+  ~SocketSessionObj() { Shutdown(); }
+
+  static constexpr const char* _type_key = "runtime.disco.SocketSession";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SocketSessionObj, BcastSessionObj);
+  int num_nodes_;
+  int num_workers_per_node_;
+  TCPSocket socket_;
+  std::vector<TCPSocket> remote_sockets_;
+  std::vector<std::unique_ptr<DiscoSocketChannel>> remote_channels_;
+  BcastSession local_session_{nullptr};
+};
+
+TVM_REGISTER_OBJECT_TYPE(SocketSessionObj);
+
+class RemoteSocketSession {
+ public:
+  explicit RemoteSocketSession(const String& server_host, int server_port, int 
num_local_workers) {
+    socket_.Create();
+    socket_.SetKeepAlive(true);
+    SockAddr server_addr{server_host.c_str(), server_port};
+    Socket::Startup();
+    if (!socket_.Connect(server_addr)) {
+      LOG(FATAL) << "Failed to connect to server " << server_addr.AsString()
+                 << ", errno = " << Socket::GetLastErrorCode();
+    }
+    channel_ = std::make_unique<DiscoSocketChannel>(socket_);
+    TVMArgs metadata = channel_->Recv();
+    ICHECK_EQ(metadata.size(), 4);
+    num_nodes_ = metadata[0].operator int();
+    num_workers_per_node_ = metadata[1].operator int();
+    num_groups_ = metadata[2].operator int();
+    node_id_ = metadata[3].operator int();
+    CHECK_GE(num_local_workers, num_workers_per_node_);
+    InitLocalSession();
+  }
+
+  void MainLoop() {
+    while (true) {
+      TVMArgs args = channel_->Recv();
+      DiscoSocketAction action = 
static_cast<DiscoSocketAction>(args[0].operator int());
+      int worker_id = args[1].operator int();
+      int local_worker_id = worker_id - node_id_ * num_workers_per_node_;
+      switch (action) {
+        case DiscoSocketAction::kSend: {
+          args = TVMArgs(args.values + 2, args.type_codes + 2, args.size() - 
2);
+          if (worker_id == -1) {
+            local_session_->BroadcastPacked(args);
+          } else {
+            local_session_->SendPacked(local_worker_id, args);
+          }
+          break;
+        }
+        case DiscoSocketAction::kReceive: {
+          args = local_session_->RecvReplyPacked(local_worker_id);
+          channel_->Reply(args);
+          break;
+        }
+        case DiscoSocketAction::kShutdown: {
+          local_session_->Shutdown();
+          LOG(INFO) << "Connection closed by remote controller.";
+          return;
+        }
+        default:
+          LOG(FATAL) << "Invalid action " << static_cast<int>(action);
+      }
+    }
+  }
+
+  ~RemoteSocketSession() {
+    socket_.Close();
+    Socket::Finalize();
+  }
+
+ private:
+  void InitLocalSession() {
+    const PackedFunc* f_create_local_session =
+        Registry::Get("runtime.disco.create_socket_session_local_workers");
+    local_session_ = 
((*f_create_local_session)(num_workers_per_node_)).AsObjectRef<BcastSession>();
+
+    DRef f_init_workers =
+        
local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers");
+    local_session_->CallPacked(f_init_workers, num_nodes_, node_id_, 
num_groups_,
+                               num_workers_per_node_);
+  }
+
+  TCPSocket socket_;
+  BcastSession local_session_{nullptr};
+  std::unique_ptr<DiscoSocketChannel> channel_;
+  int num_nodes_{-1};
+  int node_id_{-1};
+  int num_groups_{-1};
+  int num_workers_per_node_{-1};
+};
+
+void RemoteSocketSessionEntryPoint(const String& server_host, int server_port,
+                                   int num_local_workers) {
+  RemoteSocketSession proxy(server_host, server_port, num_local_workers);
+  proxy.MainLoop();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession")
+    .set_body_typed(RemoteSocketSessionEntryPoint);
+
+Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, 
const String& host,
+                      int port) {
+  auto n = make_object<SocketSessionObj>(num_nodes, num_workers_per_node, 
num_groups, host, port);
+  return Session(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession);
+
+TVM_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers")
+    .set_body_typed([](int num_nodes, int node_id, int num_groups, int 
num_workers_per_node) {
+      LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, "
+                << num_workers_per_node << " workers per node, and " << 
num_groups << " groups.";
+      DiscoWorker* worker = DiscoWorker::ThreadLocal();
+      worker->num_groups = num_groups;
+      worker->worker_id = worker->worker_id + node_id * num_workers_per_node;
+      worker->num_workers = num_nodes * num_workers_per_node;
+    });
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/disco/message_queue.h 
b/src/runtime/disco/message_queue.h
new file mode 100644
index 0000000000..3b78c3e5c1
--- /dev/null
+++ b/src/runtime/disco/message_queue.h
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
+#define TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
+
+#include <dmlc/io.h>
+
+#include <string>
+
+#include "./protocol.h"
+
+namespace tvm {
+namespace runtime {
+
+class DiscoStreamMessageQueue : private dmlc::Stream,
+                                private DiscoProtocol<DiscoStreamMessageQueue> 
{
+ public:
+  explicit DiscoStreamMessageQueue(Stream* stream) : stream_(stream) {}
+
+  ~DiscoStreamMessageQueue() = default;
+
+  void Send(const TVMArgs& args) {
+    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, 
this);
+    CommitSendAndNotifyEnqueue();
+  }
+
+  TVMArgs Recv() {
+    bool is_implicit_shutdown = DequeueNextPacket();
+    TVMValue* values = nullptr;
+    int* type_codes = nullptr;
+    int num_args = 0;
+
+    if (is_implicit_shutdown) {
+      num_args = 2;
+      values = ArenaAlloc<TVMValue>(num_args);
+      type_codes = ArenaAlloc<int>(num_args);
+      TVMArgsSetter setter(values, type_codes);
+      setter(0, static_cast<int>(DiscoAction::kShutDown));
+      setter(1, 0);
+    } else {
+      RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
+    }
+    return TVMArgs(values, type_codes, num_args);
+  }
+
+ protected:
+  void CommitSendAndNotifyEnqueue() {
+    stream_->Write(write_buffer_.data(), write_buffer_.size());
+    write_buffer_.clear();
+  }
+
+  /* \brief Read next packet and reset unpacker
+   *
+   * Read the next packet into `read_buffer_`, releasing all arena
+   * allocations performed by the unpacker and resetting the unpacker
+   * to its initial state.
+   *
+   * \return A boolean value.  If true, this packet should be treated
+   *    equivalently to a `DiscoAction::kShutdown` event.  If false,
+   *    this packet should be unpacked.
+   */
+  bool DequeueNextPacket() {
+    uint64_t packet_nbytes = 0;
+    int read_size = stream_->Read(&packet_nbytes, sizeof(packet_nbytes));
+    if (read_size == 0) {
+      // Special case, connection dropped between packets.  Treat as a
+      // request to shutdown.
+      return true;
+    }
+
+    ICHECK_EQ(read_size, sizeof(packet_nbytes))
+        << "Stream closed without proper shutdown. Please make sure to 
explicitly call "
+           "`Session::Shutdown`";
+    read_buffer_.resize(packet_nbytes);
+    read_size = stream_->Read(read_buffer_.data(), packet_nbytes);
+    ICHECK_EQ(read_size, packet_nbytes)
+        << "Stream closed without proper shutdown. Please make sure to 
explicitly call "
+           "`Session::Shutdown`";
+    read_offset_ = 0;
+    this->RecycleAll();
+    RPCCode code = RPCCode::kReturn;
+    this->Read(&code);
+    return false;
+  }
+
+  size_t Read(void* data, size_t size) final {
+    std::memcpy(data, read_buffer_.data() + read_offset_, size);
+    read_offset_ += size;
+    ICHECK_LE(read_offset_, read_buffer_.size());
+    return size;
+  }
+
+  size_t Write(const void* data, size_t size) final {
+    size_t cur_size = write_buffer_.size();
+    write_buffer_.resize(cur_size + size);
+    std::memcpy(write_buffer_.data() + cur_size, data, size);
+    return size;
+  }
+
+  using dmlc::Stream::Read;
+  using dmlc::Stream::ReadArray;
+  using dmlc::Stream::Write;
+  using dmlc::Stream::WriteArray;
+  friend struct RPCReference;
+  friend struct DiscoProtocol<DiscoStreamMessageQueue>;
+
+  // The read/write buffer will only be accessed by the producer thread.
+  std::string write_buffer_;
+  std::string read_buffer_;
+  size_t read_offset_ = 0;
+  dmlc::Stream* stream_;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 35e8fd06b3..d35fc911c6 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -86,7 +86,8 @@ void InitCCLPerWorker(IntTuple device_ids, std::string 
unique_id_bytes) {
                       << "and has not been destructed";
 
   // Step up local context of NCCL
-  int device_id = device_ids[worker->worker_id];
+  int group_size = worker->num_workers / worker->num_groups;
+  int device_id = device_ids[worker->local_worker_id];
   SetDevice(device_id);
 #if TVM_NCCL_RCCL_SWITCH == 0
   StreamCreate(&ctx->default_stream);
@@ -99,7 +100,6 @@ void InitCCLPerWorker(IntTuple device_ids, std::string 
unique_id_bytes) {
   // Initialize the communicator
   ncclUniqueId id;
   std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES);
-  int group_size = worker->num_workers / worker->num_groups;
   NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, 
worker->worker_id));
   NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size,
                           worker->worker_id % group_size, &ctx->group_comm, 
NULL));
diff --git a/src/runtime/disco/process_session.cc 
b/src/runtime/disco/process_session.cc
index 7c8d0796dd..161c3f6e04 100644
--- a/src/runtime/disco/process_session.cc
+++ b/src/runtime/disco/process_session.cc
@@ -31,114 +31,19 @@
 #include "../minrpc/rpc_reference.h"
 #include "./bcast_session.h"
 #include "./disco_worker_thread.h"
+#include "./message_queue.h"
 #include "./protocol.h"
 
 namespace tvm {
 namespace runtime {
 
-class DiscoPipeMessageQueue : private dmlc::Stream, private 
DiscoProtocol<DiscoPipeMessageQueue> {
- public:
-  explicit DiscoPipeMessageQueue(int64_t handle) : pipe_(handle) {}
-
-  ~DiscoPipeMessageQueue() = default;
-
-  void Send(const TVMArgs& args) {
-    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, 
this);
-    CommitSendAndNotifyEnqueue();
-  }
-
-  TVMArgs Recv() {
-    bool is_implicit_shutdown = DequeueNextPacket();
-    TVMValue* values = nullptr;
-    int* type_codes = nullptr;
-    int num_args = 0;
-
-    if (is_implicit_shutdown) {
-      num_args = 2;
-      values = ArenaAlloc<TVMValue>(num_args);
-      type_codes = ArenaAlloc<int>(num_args);
-      TVMArgsSetter setter(values, type_codes);
-      setter(0, static_cast<int>(DiscoAction::kShutDown));
-      setter(1, 0);
-    } else {
-      RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
-    }
-    return TVMArgs(values, type_codes, num_args);
-  }
-
- protected:
-  void CommitSendAndNotifyEnqueue() {
-    pipe_.Write(write_buffer_.data(), write_buffer_.size());
-    write_buffer_.clear();
-  }
-
-  /* \brief Read next packet and reset unpacker
-   *
-   * Read the next packet into `read_buffer_`, releasing all arena
-   * allocations performed by the unpacker and resetting the unpacker
-   * to its initial state.
-   *
-   * \return A boolean value.  If true, this packet should be treated
-   *    equivalently to a `DiscoAction::kShutdown` event.  If false,
-   *    this packet should be unpacked.
-   */
-  bool DequeueNextPacket() {
-    uint64_t packet_nbytes = 0;
-    int read_size = pipe_.Read(&packet_nbytes, sizeof(packet_nbytes));
-    if (read_size == 0) {
-      // Special case, connection dropped between packets.  Treat as a
-      // request to shutdown.
-      return true;
-    }
-
-    ICHECK_EQ(read_size, sizeof(packet_nbytes))
-        << "Pipe closed without proper shutdown. Please make sure to 
explicitly call "
-           "`Session::Shutdown`";
-    read_buffer_.resize(packet_nbytes);
-    read_size = pipe_.Read(read_buffer_.data(), packet_nbytes);
-    ICHECK_EQ(read_size, packet_nbytes)
-        << "Pipe closed without proper shutdown. Please make sure to 
explicitly call "
-           "`Session::Shutdown`";
-    read_offset_ = 0;
-    this->RecycleAll();
-    RPCCode code = RPCCode::kReturn;
-    this->Read(&code);
-    return false;
-  }
-
-  size_t Read(void* data, size_t size) final {
-    std::memcpy(data, read_buffer_.data() + read_offset_, size);
-    read_offset_ += size;
-    ICHECK_LE(read_offset_, read_buffer_.size());
-    return size;
-  }
-
-  size_t Write(const void* data, size_t size) final {
-    size_t cur_size = write_buffer_.size();
-    write_buffer_.resize(cur_size + size);
-    std::memcpy(write_buffer_.data() + cur_size, data, size);
-    return size;
-  }
-
-  using dmlc::Stream::Read;
-  using dmlc::Stream::ReadArray;
-  using dmlc::Stream::Write;
-  using dmlc::Stream::WriteArray;
-  friend struct RPCReference;
-  friend struct DiscoProtocol<DiscoPipeMessageQueue>;
-
-  // The read/write buffer will only be accessed by the producer thread.
-  std::string write_buffer_;
-  std::string read_buffer_;
-  size_t read_offset_ = 0;
-  support::Pipe pipe_;
-};
-
 class DiscoProcessChannel final : public DiscoChannel {
  public:
   DiscoProcessChannel(int64_t controler_to_worker_fd, int64_t 
worker_to_controler_fd)
-      : controler_to_worker_(controler_to_worker_fd),
-        worker_to_controler_(worker_to_controler_fd) {}
+      : controller_to_worker_pipe_(controler_to_worker_fd),
+        worker_to_controller_pipe_(worker_to_controler_fd),
+        controler_to_worker_(&controller_to_worker_pipe_),
+        worker_to_controler_(&worker_to_controller_pipe_) {}
 
   DiscoProcessChannel(DiscoProcessChannel&& other) = delete;
   DiscoProcessChannel(const DiscoProcessChannel& other) = delete;
@@ -148,8 +53,10 @@ class DiscoProcessChannel final : public DiscoChannel {
   void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); }
   TVMArgs RecvReply() { return worker_to_controler_.Recv(); }
 
-  DiscoPipeMessageQueue controler_to_worker_;
-  DiscoPipeMessageQueue worker_to_controler_;
+  support::Pipe controller_to_worker_pipe_;
+  support::Pipe worker_to_controller_pipe_;
+  DiscoStreamMessageQueue controler_to_worker_;
+  DiscoStreamMessageQueue worker_to_controler_;
 };
 
 class ProcessSessionObj final : public BcastSessionObj {
@@ -226,7 +133,7 @@ class ProcessSessionObj final : public BcastSessionObj {
       int type_codes[4];
       PackArgs(values, type_codes, static_cast<int>(DiscoAction::kDebugSetRegister), reg_id,
                worker_id, value);
-      workers_[worker_id - 1]->Send(TVMArgs(values, type_codes, 4));
+      SendPacked(worker_id, TVMArgs(values, type_codes, 4));
     }
     TVMRetValue result;
     TVMArgs args = this->RecvReplyPacked(worker_id);
@@ -241,6 +148,14 @@ class ProcessSessionObj final : public BcastSessionObj {
     }
   }
 
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    if (worker_id == 0) {
+      worker_0_->channel->Send(args);
+    } else {
+      workers_.at(worker_id - 1)->Send(args);
+    }
+  }
+
   TVMArgs RecvReplyPacked(int worker_id) final {
     if (worker_id == 0) {
       return worker_0_->channel->RecvReply();
@@ -248,6 +163,13 @@ class ProcessSessionObj final : public BcastSessionObj {
     return this->workers_.at(worker_id - 1)->RecvReply();
   }
 
+  DiscoChannel* GetWorkerChannel(int worker_id) {
+    if (worker_id == 0) {
+      return worker_0_->channel.get();
+    }
+    return workers_.at(worker_id - 1).get();
+  }
+
   PackedFunc process_pool_;
   std::unique_ptr<DiscoWorkerThread> worker_0_;
   std::vector<std::unique_ptr<DiscoProcessChannel>> workers_;
diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc
index cc9a311a6b..bf6b6107e1 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -173,6 +173,10 @@ class ThreadedSessionObj final : public BcastSessionObj {
     }
   }
 
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    this->workers_.at(worker_id).channel->Send(args);
+  }
+
   TVMArgs RecvReplyPacked(int worker_id) final {
     return this->workers_.at(worker_id).channel->RecvReply();
   }
diff --git a/src/support/socket.h b/src/support/socket.h
index ac13cd3f2d..032cf257c0 100644
--- a/src/support/socket.h
+++ b/src/support/socket.h
@@ -370,7 +370,7 @@ class Socket {
 /*!
  * \brief a wrapper of TCP socket that hopefully be cross platform
  */
-class TCPSocket : public Socket {
+class TCPSocket : public Socket, public dmlc::Stream {
  public:
   TCPSocket() : Socket(INVALID_SOCKET) {}
   /*!
@@ -552,6 +552,10 @@ class TCPSocket : public Socket {
     ICHECK_EQ(RecvAll(&data[0], datalen), datalen);
     return data;
   }
+
+  size_t Read(void* data, size_t size) final { return Recv(data, size); }
+
+  size_t Write(const void* data, size_t size) final { return Send(data, size); }
 };
 
 /*! \brief helper data structure to perform poll */
diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py
index 837b3a14f2..38aa757bf8 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -20,6 +20,9 @@ import tempfile
 
 import numpy as np
 import pytest
+import subprocess
+import threading
+import sys
 
 import tvm
 import tvm.testing
@@ -29,7 +32,7 @@ from tvm.runtime import disco as di
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
-from tvm.exec import disco_worker as _
+from tvm.exec import disco_worker as _  # pylint: disable=unused-import
 
 
 def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
@@ -46,7 +49,75 @@ def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype):
     return host_array.numpy()
 
 
-_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
+_SOCKET_SESSION_TESTER = None
+
+
+def get_free_port():
+    import socket
+
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("", 0))
+    port = s.getsockname()[1]
+    s.close()
+    return port
+
+
+class SocketSessionTester:
+    def __init__(self, num_workers):
+        num_nodes = 2
+        num_groups = 1
+        assert num_workers % num_nodes == 0
+        num_workers_per_node = num_workers // num_nodes
+        server_host = "localhost"
+        server_port = get_free_port()
+        self.sess = None
+
+        def start_server():
+            self.sess = di.SocketSession(
+                num_nodes, num_workers_per_node, num_groups, server_host, server_port
+            )
+
+        thread = threading.Thread(target=start_server)
+        thread.start()
+
+        cmd = "tvm.exec.disco_remote_socket_session"
+        self.remote_nodes = []
+        for _ in range(num_nodes - 1):
+            self.remote_nodes.append(
+                subprocess.Popen(
+                    [
+                        "python3",
+                        "-m",
+                        cmd,
+                        server_host,
+                        str(server_port),
+                        str(num_workers_per_node),
+                    ],
+                    stdout=sys.stdout,
+                    stderr=sys.stderr,
+                )
+            )
+
+        thread.join()
+
+    def __del__(self):
+        for node in self.remote_nodes:
+            node.kill()
+        if self.sess is not None:
+            self.sess.shutdown()
+            del self.sess
+
+
+def create_socket_session(num_workers):
+    global _SOCKET_SESSION_TESTER
+    if _SOCKET_SESSION_TESTER is not None:
+        del _SOCKET_SESSION_TESTER
+    _SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
+    assert _SOCKET_SESSION_TESTER.sess is not None
+    return _SOCKET_SESSION_TESTER.sess
+
+
+_all_session_kinds = [di.ThreadedSession, di.ProcessSession, create_socket_session]
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
@@ -157,6 +228,11 @@ def test_vm_module(session_kind):
+        y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype)
         np.testing.assert_equal(y_nd, y_np)
 
+        # sync all workers to make sure the temporary files are cleaned up after all workers
+        # finish the execution
+        for i in range(num_workers):
+            sess._sync_worker(i)
+
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 def test_vm_multi_func(session_kind):
@@ -220,10 +296,17 @@ def test_vm_multi_func(session_kind):
         np.testing.assert_equal(y_nd, y_np)
         np.testing.assert_equal(z_nd, x_np)
 
+        # sync all workers to make sure the temporary files are cleaned up after all workers
+        # finish the execution
+        for i in range(num_workers):
+            sess._sync_worker(i)
+
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 @pytest.mark.parametrize("num_workers", [1, 2, 4])
 def test_num_workers(session_kind, num_workers):
+    if session_kind == create_socket_session and num_workers < 2:
+        return
     sess = session_kind(num_workers=num_workers)
     assert sess.num_workers == num_workers
 


Reply via email to