This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1b6c00d756 [Disco] Implement SocketSession (#17182)
1b6c00d756 is described below
commit 1b6c00d7560afded9b5380abfd3f182461b9448d
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Jul 25 21:11:33 2024 -0700
[Disco] Implement SocketSession (#17182)
* [Disco] Implement SocketSession
Implements SocketSession that connects multiple local worker
processes/threads over multiple distributed nodes via TCP socket.
* doc
* lint
* resolve conflict
* lint
* add local worker id
* lint
* lint
* disable for hexagon
* remove from header
---
CMakeLists.txt | 6 +
include/tvm/runtime/disco/disco_worker.h | 4 +
include/tvm/runtime/disco/session.h | 1 +
.../disco_remote_socket_session.py} | 26 +-
python/tvm/runtime/disco/__init__.py | 1 +
python/tvm/runtime/disco/session.py | 23 ++
src/runtime/disco/bcast_session.h | 20 ++
src/runtime/disco/disco_worker.cc | 4 +-
src/runtime/disco/distributed/socket_session.cc | 332 +++++++++++++++++++++
src/runtime/disco/message_queue.h | 133 +++++++++
src/runtime/disco/nccl/nccl.cc | 4 +-
src/runtime/disco/process_session.cc | 128 ++------
src/runtime/disco/threaded_session.cc | 4 +
src/support/socket.h | 6 +-
tests/python/disco/test_session.py | 87 +++++-
15 files changed, 660 insertions(+), 119 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7575d6c2b4..7fba5355f0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -387,6 +387,12 @@ if(BUILD_FOR_HEXAGON)
add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0)
endif()
+# The distributed disco runtime (socket-based multi-node sessions under
+# src/runtime/disco/distributed) is disabled for Hexagon builds.
+if (NOT BUILD_FOR_HEXAGON)
+  tvm_file_glob(GLOB RUNTIME_DISCO_DISTRIBUTED_SRCS src/runtime/disco/distributed/*.cc)
+  list(APPEND RUNTIME_SRCS ${RUNTIME_DISCO_DISTRIBUTED_SRCS})
+endif()
+
# Package runtime rules
if(NOT USE_RTTI)
add_definitions(-DDMLC_ENABLE_RTTI=0)
diff --git a/include/tvm/runtime/disco/disco_worker.h
b/include/tvm/runtime/disco/disco_worker.h
index 13f94802c8..c9c85b7dbf 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -52,6 +52,7 @@ class DiscoWorker {
explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
WorkerZeroData* worker_zero_data, DiscoChannel* channel)
: worker_id(worker_id),
+ local_worker_id(worker_id),
num_workers(num_workers),
num_groups(num_groups),
default_device(Device{DLDeviceType::kDLCPU, 0}),
@@ -68,6 +69,9 @@ class DiscoWorker {
/*! \brief The id of the worker.*/
int worker_id;
+  /*! \brief The local id of the worker within its own (sub-)session. It starts out equal to
+   * worker_id and can differ from it when the session is composed of multiple sub-sessions
+   * (e.g. one local session per node). */
+  int local_worker_id;
/*! \brief Total number of workers */
int num_workers;
/*! \brief Total number of workers */
diff --git a/include/tvm/runtime/disco/session.h
b/include/tvm/runtime/disco/session.h
index 97fa79096d..9c34f8a2af 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -281,6 +281,7 @@ class Session : public ObjectRef {
*/
TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
String process_pool_creator, String
entrypoint);
+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef,
SessionObj);
};
diff --git a/python/tvm/runtime/disco/__init__.py
b/python/tvm/exec/disco_remote_socket_session.py
similarity index 58%
copy from python/tvm/runtime/disco/__init__.py
copy to python/tvm/exec/disco_remote_socket_session.py
index 856e69bc35..3111ce30ac 100644
--- a/python/tvm/runtime/disco/__init__.py
+++ b/python/tvm/exec/disco_remote_socket_session.py
@@ -14,12 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM distributed runtime API."""
-from .session import (
- DModule,
- DPackedFunc,
- DRef,
- ProcessSession,
- Session,
- ThreadedSession,
-)
+# pylint: disable=invalid-name
+"""Launch disco session in the remote node and connect to the server."""
+import sys
+import tvm
+from . import disco_worker as _  # pylint: disable=unused-import
+
+
+def main() -> None:
+    """Connect this node's local workers to a SocketSession controller.
+
+    Expects exactly three command-line arguments: the controller host, the
+    controller port, and the number of local workers on this node. The call
+    into ``runtime.disco.RemoteSocketSession`` serves controller requests and
+    returns once the controller shuts the session down.
+    """
+    if len(sys.argv) != 4:
+        # Usage errors go to stderr so they are not mixed with normal output.
+        print(
+            "Usage: python -m tvm.exec.disco_remote_socket_session "
+            "<server_host> <server_port> <num_workers>",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+
+    server_host = sys.argv[1]
+    server_port = int(sys.argv[2])
+    num_workers = int(sys.argv[3])
+    func = tvm.get_global_func("runtime.disco.RemoteSocketSession")
+    func(server_host, server_port, num_workers)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/tvm/runtime/disco/__init__.py
b/python/tvm/runtime/disco/__init__.py
index 856e69bc35..2ba524cade 100644
--- a/python/tvm/runtime/disco/__init__.py
+++ b/python/tvm/runtime/disco/__init__.py
@@ -22,4 +22,5 @@ from .session import (
ProcessSession,
Session,
ThreadedSession,
+ SocketSession,
)
diff --git a/python/tvm/runtime/disco/session.py
b/python/tvm/runtime/disco/session.py
index 89ef549df3..1749942a9c 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -574,6 +574,29 @@ class ProcessSession(Session):
func(config, os.getpid())
+@register_func("runtime.disco.create_socket_session_local_workers")
+def _create_socket_session_local_workers(num_workers: int) -> Session:
+    """Create the local session for each distributed node over socket session.
+
+    Looked up from C++ by its registered name, both by the SocketSession
+    controller and by every remote node, to build that node's sub-session.
+    Every node currently uses a ProcessSession with ``num_workers`` workers.
+    """
+    return ProcessSession(num_workers)
+
+
+@register_object("runtime.disco.SocketSession")
+class SocketSession(Session):
+    """A Disco session backed by socket-based multi-node communication.
+
+    The controller hosts node 0's workers in a local sub-session and listens on
+    ``host:port``. Construction blocks until ``num_nodes - 1`` remote nodes have
+    connected; each remote node is expected to run
+    ``python -m tvm.exec.disco_remote_socket_session <host> <port> <num_workers>``.
+
+    Parameters
+    ----------
+    num_nodes : int
+        Total number of nodes, including the controller's own node.
+    num_workers_per_node : int
+        Number of workers on each node.
+    num_groups : int
+        Number of worker groups.
+    host : str
+        Address the controller binds and listens on.
+    port : int
+        Port the controller binds and listens on.
+    """
+
+    def __init__(
+        self, num_nodes: int, num_workers_per_node: int, num_groups: int, host: str, port: int
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.SocketSession,  # type: ignore # pylint: disable=no-member
+            num_nodes,
+            num_workers_per_node,
+            num_groups,
+            host,
+            port,
+        )
+
+
@register_func("runtime.disco._configure_structlog")
def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None:
"""Configure structlog for all disco workers
diff --git a/src/runtime/disco/bcast_session.h
b/src/runtime/disco/bcast_session.h
index 1a4df634b7..0e4ca614d4 100644
--- a/src/runtime/disco/bcast_session.h
+++ b/src/runtime/disco/bcast_session.h
@@ -65,6 +65,16 @@ class BcastSessionObj : public SessionObj {
* \param TVMArgs The input arguments in TVM's PackedFunc calling convention
*/
virtual void BroadcastPacked(const TVMArgs& args) = 0;
+
+  /*!
+   * \brief Send a packed sequence to a single worker. This function is usually called by the
+   * controller to communicate with worker-0, because worker-0 is assumed to always be colocated
+   * with the controller. Some implementations may not support sending to other workers.
+   * \param worker_id The worker id to send the packed sequence to.
+   * \param args The packed sequence to send.
+   */
+  virtual void SendPacked(int worker_id, const TVMArgs& args) = 0;
+
/*!
* \brief Receive a packed sequence from a worker. This function is usually
called by the
* controler to communicate with worker-0, because the worker-0 is assumed
to be always
@@ -83,6 +93,16 @@ class BcastSessionObj : public SessionObj {
struct Internal;
friend struct Internal;
+ friend class SocketSessionObj;
+ friend class RemoteSocketSession;
+};
+
+/*!
+ * \brief Managed reference to BcastSessionObj.
+ */
+class BcastSession : public Session {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BcastSession, Session,
BcastSessionObj);
};
} // namespace runtime
diff --git a/src/runtime/disco/disco_worker.cc
b/src/runtime/disco/disco_worker.cc
index 5e6f401054..4007b104f2 100644
--- a/src/runtime/disco/disco_worker.cc
+++ b/src/runtime/disco/disco_worker.cc
@@ -120,7 +120,7 @@ struct DiscoWorker::Impl {
}
static void CopyFromWorker0(DiscoWorker* self, int reg_id) {
- if (self->worker_zero_data != nullptr) {
+ if (self->worker_id == 0) {
NDArray tgt = GetNDArrayFromHost(self);
NDArray src = GetReg(self, reg_id);
tgt.CopyFrom(src);
@@ -128,7 +128,7 @@ struct DiscoWorker::Impl {
}
static void CopyToWorker0(DiscoWorker* self, int reg_id) {
- if (self->worker_zero_data != nullptr) {
+ if (self->worker_id == 0) {
NDArray src = GetNDArrayFromHost(self);
NDArray tgt = GetReg(self, reg_id);
tgt.CopyFrom(src);
diff --git a/src/runtime/disco/distributed/socket_session.cc
b/src/runtime/disco/distributed/socket_session.cc
new file mode 100644
index 0000000000..07196be305
--- /dev/null
+++ b/src/runtime/disco/distributed/socket_session.cc
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/registry.h>
+
+#include <numeric>
+
+#include "../../../support/socket.h"
+#include "../bcast_session.h"
+#include "../message_queue.h"
+
+namespace tvm {
+namespace runtime {
+
+using namespace tvm::support;
+
+/*!
+ * \brief Commands sent by the controller to a remote node; every packet starts with
+ *        (action, worker_id).
+ *
+ * kShutdown deliberately reuses DiscoAction::kShutDown's value: when the connection drops,
+ * DiscoStreamMessageQueue::Recv synthesizes a (DiscoAction::kShutDown, 0) packet, which the
+ * remote main loop then decodes as kShutdown.
+ */
+enum class DiscoSocketAction {
+  kShutdown = static_cast<int>(DiscoAction::kShutDown),
+  kSend,     // forward the payload to one worker, or broadcast it when worker_id == -1
+  kReceive,  // receive one reply from a worker and relay it back to the controller
+};
+
+/*!
+ * \brief A DiscoChannel that exchanges packed sequences over a TCP socket.
+ *
+ * Both directions share a single DiscoStreamMessageQueue, so Send/Reply and Recv/RecvReply
+ * perform the same operation on this channel.
+ */
+class DiscoSocketChannel : public DiscoChannel {
+ public:
+  explicit DiscoSocketChannel(const TCPSocket& socket)
+      : socket_(socket), message_queue_(&socket_) {}
+
+  // message_queue_ stores a pointer to the socket_ member, so the channel must not be
+  // copied or moved.
+  DiscoSocketChannel(DiscoSocketChannel&& other) = delete;
+  DiscoSocketChannel(const DiscoSocketChannel& other) = delete;
+  void Send(const TVMArgs& args) { message_queue_.Send(args); }
+  TVMArgs Recv() { return message_queue_.Recv(); }
+  void Reply(const TVMArgs& args) { message_queue_.Send(args); }
+  TVMArgs RecvReply() { return message_queue_.Recv(); }
+
+ private:
+  // Copy of the caller's socket handle. The channel never closes it; the owner of the
+  // original TCPSocket is responsible for that.
+  TCPSocket socket_;
+  DiscoStreamMessageQueue message_queue_;
+};
+
+/*!
+ * \brief Controller side of a disco session spanning several nodes connected over TCP.
+ *
+ * Node 0's workers run in a local BcastSession owned by this object. Every other node runs a
+ * RemoteSocketSession that connects to the listening socket and is driven through a
+ * DiscoSocketChannel. Global worker ids are node-major: node_id = worker_id / num_workers_per_node.
+ */
+class SocketSessionObj : public BcastSessionObj {
+ public:
+  explicit SocketSessionObj(int num_nodes, int num_workers_per_node, int num_groups,
+                            const String& host, int port)
+      : num_nodes_(num_nodes), num_workers_per_node_(num_workers_per_node) {
+    const PackedFunc* f_create_local_session =
+        Registry::Get("runtime.disco.create_socket_session_local_workers");
+    ICHECK(f_create_local_session != nullptr)
+        << "Cannot find function runtime.disco.create_socket_session_local_workers";
+    local_session_ = ((*f_create_local_session)(num_workers_per_node)).AsObjectRef<BcastSession>();
+    DRef f_init_workers =
+        local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers");
+    local_session_->CallPacked(f_init_workers, num_nodes_, /*node_id=*/0, num_groups,
+                               num_workers_per_node_);
+
+    Socket::Startup();
+    socket_.Create();
+    socket_.SetKeepAlive(true);
+    socket_.Bind(SockAddr(host.c_str(), port));
+    socket_.Listen();
+    LOG(INFO) << "SocketSession controller listening on " << host << ":" << port;
+
+    TVMValue values[4];
+    int type_codes[4];
+    TVMArgsSetter setter(values, type_codes);
+    setter(0, num_nodes);
+    setter(1, num_workers_per_node);
+    setter(2, num_groups);
+
+    // Block until every remote node has connected; node ids 1..num_nodes-1 are assigned in
+    // connection order.
+    for (int i = 0; i + 1 < num_nodes; ++i) {
+      SockAddr addr;
+      remote_sockets_.push_back(socket_.Accept(&addr));
+      remote_channels_.emplace_back(std::make_unique<DiscoSocketChannel>(remote_sockets_.back()));
+      setter(3, i + 1);
+      // Send metadata to each remote node:
+      //  - num_nodes
+      //  - num_workers_per_node
+      //  - num_groups
+      //  - node_id
+      remote_channels_.back()->Send(TVMArgs(values, type_codes, 4));
+      LOG(INFO) << "Remote node " << addr.AsString() << " connected";
+    }
+  }
+
+  int64_t GetNumWorkers() final { return num_nodes_ * num_workers_per_node_; }
+
+  TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      return local_session_->DebugGetFromRemote(reg_id, worker_id);
+    }
+    // Forward the request to the owning node as a kSend, then fetch the worker's reply.
+    TVMValue values[5];
+    int type_codes[5];
+    PackArgs(values, type_codes, static_cast<int>(DiscoSocketAction::kSend), worker_id,
+             static_cast<int>(DiscoAction::kDebugGetFromRemote), reg_id, worker_id);
+    remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 5));
+    TVMArgs args = this->RecvReplyPacked(worker_id);
+    ICHECK_EQ(args.size(), 2);
+    ICHECK(static_cast<DiscoAction>(args[0].operator int()) == DiscoAction::kDebugGetFromRemote);
+    TVMRetValue result;
+    result = args[1];
+    return result;
+  }
+
+  void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      local_session_->DebugSetRegister(reg_id, value, worker_id);
+      return;
+    }
+    // NDArray/object arguments are wrapped with DiscoDebugObject::Wrap before being packed;
+    // `wrapped` keeps the wrapper alive until the packet has been sent.
+    ObjectRef wrapped{nullptr};
+    if (value.type_code() == kTVMNDArrayHandle || value.type_code() == kTVMObjectHandle) {
+      wrapped = DiscoDebugObject::Wrap(value);
+      TVMValue tvm_value;
+      int type_code = kTVMObjectHandle;
+      tvm_value.v_handle = const_cast<Object*>(wrapped.get());
+      value = TVMArgValue(tvm_value, type_code);
+    }
+    {
+      TVMValue values[6];
+      int type_codes[6];
+      PackArgs(values, type_codes, static_cast<int>(DiscoSocketAction::kSend), worker_id,
+               static_cast<int>(DiscoAction::kDebugSetRegister), reg_id, worker_id, value);
+      remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 6));
+    }
+    // Wait for the worker's acknowledgement.
+    TVMArgs args = this->RecvReplyPacked(worker_id);
+    ICHECK_EQ(args.size(), 1);
+    ICHECK(static_cast<DiscoAction>(args[0].operator int()) == DiscoAction::kDebugSetRegister);
+  }
+
+  void BroadcastPacked(const TVMArgs& args) final {
+    local_session_->BroadcastPacked(args);
+    // Prefix (kSend, -1): worker id -1 tells each remote node to broadcast to all of its workers.
+    std::vector<TVMValue> values(args.size() + 2);
+    std::vector<int> type_codes(args.size() + 2);
+    PackArgs(values.data(), type_codes.data(), static_cast<int>(DiscoSocketAction::kSend), -1);
+    std::copy(args.values, args.values + args.size(), values.begin() + 2);
+    std::copy(args.type_codes, args.type_codes + args.size(), type_codes.begin() + 2);
+    for (auto& channel : remote_channels_) {
+      channel->Send(TVMArgs(values.data(), type_codes.data(), values.size()));
+    }
+  }
+
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      local_session_->SendPacked(worker_id, args);
+      return;
+    }
+    std::vector<TVMValue> values(args.size() + 2);
+    std::vector<int> type_codes(args.size() + 2);
+    PackArgs(values.data(), type_codes.data(), static_cast<int>(DiscoSocketAction::kSend),
+             worker_id);
+    std::copy(args.values, args.values + args.size(), values.begin() + 2);
+    std::copy(args.type_codes, args.type_codes + args.size(), type_codes.begin() + 2);
+    remote_channels_[node_id - 1]->Send(TVMArgs(values.data(), type_codes.data(), values.size()));
+  }
+
+  TVMArgs RecvReplyPacked(int worker_id) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      return local_session_->RecvReplyPacked(worker_id);
+    }
+    // Ask the remote node to read one reply from the worker, then receive the relayed reply.
+    TVMValue values[2];
+    int type_codes[2];
+    PackArgs(values, type_codes, static_cast<int>(DiscoSocketAction::kReceive), worker_id);
+    remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 2));
+    return remote_channels_[node_id - 1]->Recv();
+  }
+
+  void AppendHostNDArray(const NDArray& host_array) final {
+    local_session_->AppendHostNDArray(host_array);
+  }
+
+  void Shutdown() final {
+    // The destructor also calls Shutdown(). Make repeated calls no-ops so Socket::Finalize()
+    // stays balanced with the single Socket::Startup() performed in the constructor.
+    if (is_shutdown_) {
+      return;
+    }
+    is_shutdown_ = true;
+    // local session will be implicitly shutdown by its destructor
+    TVMValue values[2];
+    int type_codes[2];
+    PackArgs(values, type_codes, static_cast<int>(DiscoSocketAction::kShutdown), -1);
+    for (auto& channel : remote_channels_) {
+      channel->Send(TVMArgs(values, type_codes, 2));
+    }
+    for (auto& socket : remote_sockets_) {
+      socket.Close();
+    }
+    remote_sockets_.clear();
+    remote_channels_.clear();
+    if (!socket_.IsClosed()) {
+      socket_.Close();
+    }
+    Socket::Finalize();
+  }
+
+  ~SocketSessionObj() { Shutdown(); }
+
+  static constexpr const char* _type_key = "runtime.disco.SocketSession";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SocketSessionObj, BcastSessionObj);
+  int num_nodes_;
+  int num_workers_per_node_;
+  // Listening socket that remote nodes connect to.
+  TCPSocket socket_;
+  // Accepted connections; remote_sockets_[i] / remote_channels_[i] serve node i + 1.
+  std::vector<TCPSocket> remote_sockets_;
+  std::vector<std::unique_ptr<DiscoSocketChannel>> remote_channels_;
+  // Sub-session hosting node 0's workers.
+  BcastSession local_session_{nullptr};
+  // Set by the first Shutdown() call.
+  bool is_shutdown_{false};
+};
+
+TVM_REGISTER_OBJECT_TYPE(SocketSessionObj);
+
+/*!
+ * \brief Non-controller node of a SocketSession.
+ *
+ * Connects to the controller, receives this node's metadata, creates a local session for the
+ * node's workers, and then relays controller requests to them until shutdown.
+ */
+class RemoteSocketSession {
+ public:
+  explicit RemoteSocketSession(const String& server_host, int server_port, int num_local_workers) {
+    // As on the controller, Socket::Startup() must run before creating the socket or resolving
+    // the server address.
+    Socket::Startup();
+    socket_.Create();
+    socket_.SetKeepAlive(true);
+    SockAddr server_addr{server_host.c_str(), server_port};
+    if (!socket_.Connect(server_addr)) {
+      LOG(FATAL) << "Failed to connect to server " << server_addr.AsString()
+                 << ", errno = " << Socket::GetLastErrorCode();
+    }
+    channel_ = std::make_unique<DiscoSocketChannel>(socket_);
+    // Metadata from the controller: num_nodes, num_workers_per_node, num_groups, node_id.
+    TVMArgs metadata = channel_->Recv();
+    ICHECK_EQ(metadata.size(), 4);
+    num_nodes_ = metadata[0].operator int();
+    num_workers_per_node_ = metadata[1].operator int();
+    num_groups_ = metadata[2].operator int();
+    node_id_ = metadata[3].operator int();
+    CHECK_GE(num_local_workers, num_workers_per_node_)
+        << "This node was started with " << num_local_workers
+        << " local workers, but the controller requires " << num_workers_per_node_
+        << " workers per node";
+    InitLocalSession();
+  }
+
+  /*! \brief Serve controller requests until a shutdown request (or disconnect) arrives. */
+  void MainLoop() {
+    while (true) {
+      TVMArgs args = channel_->Recv();
+      DiscoSocketAction action = static_cast<DiscoSocketAction>(args[0].operator int());
+      int worker_id = args[1].operator int();
+      // Translate the global worker id into an index within this node's local session.
+      int local_worker_id = worker_id - node_id_ * num_workers_per_node_;
+      switch (action) {
+        case DiscoSocketAction::kSend: {
+          // Strip the (action, worker_id) prefix and forward the remaining payload.
+          args = TVMArgs(args.values + 2, args.type_codes + 2, args.size() - 2);
+          if (worker_id == -1) {
+            local_session_->BroadcastPacked(args);
+          } else {
+            local_session_->SendPacked(local_worker_id, args);
+          }
+          break;
+        }
+        case DiscoSocketAction::kReceive: {
+          args = local_session_->RecvReplyPacked(local_worker_id);
+          channel_->Reply(args);
+          break;
+        }
+        case DiscoSocketAction::kShutdown: {
+          local_session_->Shutdown();
+          LOG(INFO) << "Connection closed by remote controller.";
+          return;
+        }
+        default:
+          LOG(FATAL) << "Invalid action " << static_cast<int>(action);
+      }
+    }
+  }
+
+  ~RemoteSocketSession() {
+    socket_.Close();
+    Socket::Finalize();
+  }
+
+ private:
+  /*! \brief Create this node's local session and assign its workers their global ids. */
+  void InitLocalSession() {
+    const PackedFunc* f_create_local_session =
+        Registry::Get("runtime.disco.create_socket_session_local_workers");
+    ICHECK(f_create_local_session != nullptr)
+        << "Cannot find function runtime.disco.create_socket_session_local_workers";
+    local_session_ = ((*f_create_local_session)(num_workers_per_node_)).AsObjectRef<BcastSession>();
+
+    DRef f_init_workers =
+        local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers");
+    local_session_->CallPacked(f_init_workers, num_nodes_, node_id_, num_groups_,
+                               num_workers_per_node_);
+  }
+
+  TCPSocket socket_;
+  BcastSession local_session_{nullptr};
+  std::unique_ptr<DiscoSocketChannel> channel_;
+  int num_nodes_{-1};
+  int node_id_{-1};
+  int num_groups_{-1};
+  int num_workers_per_node_{-1};
+};
+
+/*!
+ * \brief Entry point for a non-controller node: connect to the SocketSession controller at
+ *        server_host:server_port and serve its requests until the session is shut down.
+ */
+void RemoteSocketSessionEntryPoint(const String& server_host, int server_port,
+                                   int num_local_workers) {
+  RemoteSocketSession proxy(server_host, server_port, num_local_workers);
+  proxy.MainLoop();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession")
+    .set_body_typed(RemoteSocketSessionEntryPoint);
+
+/*! \brief Create the controller-side session; blocks until all remote nodes have connected. */
+Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, const String& host,
+                      int port) {
+  auto n = make_object<SocketSessionObj>(num_nodes, num_workers_per_node, num_groups, host, port);
+  return Session(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession);
+
+// Runs on every worker of a node's local session (invoked through CallPacked on that session)
+// and converts the worker's node-local identity into its global identity.
+TVM_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers")
+    .set_body_typed([](int num_nodes, int node_id, int num_groups, int num_workers_per_node) {
+      LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, "
+                << num_workers_per_node << " workers per node, and " << num_groups
+                << " groups.";
+      DiscoWorker* worker = DiscoWorker::ThreadLocal();
+      worker->num_groups = num_groups;
+      // Derive the global id from local_worker_id, which this initializer never modifies, instead
+      // of from the current worker_id. Running the initializer again then yields the same id
+      // rather than shifting it by another node offset.
+      worker->worker_id = worker->local_worker_id + node_id * num_workers_per_node;
+      worker->num_workers = num_nodes * num_workers_per_node;
+    });
+
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/disco/message_queue.h
b/src/runtime/disco/message_queue.h
new file mode 100644
index 0000000000..3b78c3e5c1
--- /dev/null
+++ b/src/runtime/disco/message_queue.h
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
+#define TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
+
+#include <dmlc/io.h>
+
+#include <string>
+
+#include "./protocol.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Packet-framed message queue for disco packed sequences over a dmlc::Stream.
+ *
+ * On receive, each packet is read as a uint64 byte count followed by that many payload bytes,
+ * which begin with an RPCCode. The queue does not own `stream`; the caller must keep it alive.
+ */
+class DiscoStreamMessageQueue : private dmlc::Stream,
+                                private DiscoProtocol<DiscoStreamMessageQueue> {
+ public:
+  explicit DiscoStreamMessageQueue(Stream* stream) : stream_(stream) {}
+
+  ~DiscoStreamMessageQueue() = default;
+
+  /*! \brief Serialize `args` into the write buffer and flush it to the stream. */
+  void Send(const TVMArgs& args) {
+    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this);
+    CommitSendAndNotifyEnqueue();
+  }
+
+  /*!
+   * \brief Receive the next packed sequence.
+   *
+   * If the stream ends cleanly between packets, a synthetic (kShutDown, 0) sequence is returned.
+   * The returned arguments live in arena memory that is recycled by the next Recv().
+   */
+  TVMArgs Recv() {
+    bool is_implicit_shutdown = DequeueNextPacket();
+    TVMValue* values = nullptr;
+    int* type_codes = nullptr;
+    int num_args = 0;
+
+    if (is_implicit_shutdown) {
+      num_args = 2;
+      values = ArenaAlloc<TVMValue>(num_args);
+      type_codes = ArenaAlloc<int>(num_args);
+      TVMArgsSetter setter(values, type_codes);
+      setter(0, static_cast<int>(DiscoAction::kShutDown));
+      setter(1, 0);
+    } else {
+      RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
+    }
+    return TVMArgs(values, type_codes, num_args);
+  }
+
+ protected:
+  void CommitSendAndNotifyEnqueue() {
+    stream_->Write(write_buffer_.data(), write_buffer_.size());
+    write_buffer_.clear();
+  }
+
+  /* \brief Read next packet and reset unpacker
+   *
+   * Read the next packet into `read_buffer_`, releasing all arena
+   * allocations performed by the unpacker and resetting the unpacker
+   * to its initial state.
+   *
+   * \return A boolean value. If true, this packet should be treated
+   *    equivalently to a `DiscoAction::kShutdown` event.  If false,
+   *    this packet should be unpacked.
+   */
+  bool DequeueNextPacket() {
+    uint64_t packet_nbytes = 0;
+    // Keep byte counts as size_t: narrowing to int would truncate packets of 2 GiB or more.
+    size_t read_size = stream_->Read(&packet_nbytes, sizeof(packet_nbytes));
+    if (read_size == 0) {
+      // Special case, connection dropped between packets.  Treat as a
+      // request to shutdown.
+      return true;
+    }
+
+    ICHECK_EQ(read_size, sizeof(packet_nbytes))
+        << "Stream closed without proper shutdown. Please make sure to explicitly call "
+           "`Session::Shutdown`";
+    read_buffer_.resize(packet_nbytes);
+    read_size = stream_->Read(read_buffer_.data(), packet_nbytes);
+    ICHECK_EQ(read_size, packet_nbytes)
+        << "Stream closed without proper shutdown. Please make sure to explicitly call "
+           "`Session::Shutdown`";
+    read_offset_ = 0;
+    this->RecycleAll();
+    RPCCode code = RPCCode::kReturn;
+    this->Read(&code);
+    return false;
+  }
+
+  /*! \brief Unpacker-side read from the buffered packet (not from the underlying stream). */
+  size_t Read(void* data, size_t size) final {
+    // Check the bounds before copying so a malformed packet cannot cause an over-read.
+    // Written as a subtraction (read_offset_ <= size() always holds) to avoid overflow.
+    ICHECK_LE(size, read_buffer_.size() - read_offset_);
+    std::memcpy(data, read_buffer_.data() + read_offset_, size);
+    read_offset_ += size;
+    return size;
+  }
+
+  /*! \brief Packer-side write into the pending packet buffer; flushed by Send(). */
+  size_t Write(const void* data, size_t size) final {
+    size_t cur_size = write_buffer_.size();
+    write_buffer_.resize(cur_size + size);
+    std::memcpy(write_buffer_.data() + cur_size, data, size);
+    return size;
+  }
+
+  using dmlc::Stream::Read;
+  using dmlc::Stream::ReadArray;
+  using dmlc::Stream::Write;
+  using dmlc::Stream::WriteArray;
+  friend struct RPCReference;
+  friend struct DiscoProtocol<DiscoStreamMessageQueue>;
+
+  // The read/write buffer will only be accessed by the producer thread.
+  std::string write_buffer_;
+  std::string read_buffer_;
+  size_t read_offset_ = 0;
+  // Non-owning.
+  dmlc::Stream* stream_;
+};
+
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 35e8fd06b3..d35fc911c6 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -86,7 +86,8 @@ void InitCCLPerWorker(IntTuple device_ids, std::string
unique_id_bytes) {
<< "and has not been destructed";
// Step up local context of NCCL
- int device_id = device_ids[worker->worker_id];
+ int group_size = worker->num_workers / worker->num_groups;
+ int device_id = device_ids[worker->local_worker_id];
SetDevice(device_id);
#if TVM_NCCL_RCCL_SWITCH == 0
StreamCreate(&ctx->default_stream);
@@ -99,7 +100,6 @@ void InitCCLPerWorker(IntTuple device_ids, std::string
unique_id_bytes) {
// Initialize the communicator
ncclUniqueId id;
std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES);
- int group_size = worker->num_workers / worker->num_groups;
NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id,
worker->worker_id));
NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size,
worker->worker_id % group_size, &ctx->group_comm,
NULL));
diff --git a/src/runtime/disco/process_session.cc
b/src/runtime/disco/process_session.cc
index 7c8d0796dd..161c3f6e04 100644
--- a/src/runtime/disco/process_session.cc
+++ b/src/runtime/disco/process_session.cc
@@ -31,114 +31,19 @@
#include "../minrpc/rpc_reference.h"
#include "./bcast_session.h"
#include "./disco_worker_thread.h"
+#include "./message_queue.h"
#include "./protocol.h"
namespace tvm {
namespace runtime {
-class DiscoPipeMessageQueue : private dmlc::Stream, private
DiscoProtocol<DiscoPipeMessageQueue> {
- public:
- explicit DiscoPipeMessageQueue(int64_t handle) : pipe_(handle) {}
-
- ~DiscoPipeMessageQueue() = default;
-
- void Send(const TVMArgs& args) {
- RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args,
this);
- CommitSendAndNotifyEnqueue();
- }
-
- TVMArgs Recv() {
- bool is_implicit_shutdown = DequeueNextPacket();
- TVMValue* values = nullptr;
- int* type_codes = nullptr;
- int num_args = 0;
-
- if (is_implicit_shutdown) {
- num_args = 2;
- values = ArenaAlloc<TVMValue>(num_args);
- type_codes = ArenaAlloc<int>(num_args);
- TVMArgsSetter setter(values, type_codes);
- setter(0, static_cast<int>(DiscoAction::kShutDown));
- setter(1, 0);
- } else {
- RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
- }
- return TVMArgs(values, type_codes, num_args);
- }
-
- protected:
- void CommitSendAndNotifyEnqueue() {
- pipe_.Write(write_buffer_.data(), write_buffer_.size());
- write_buffer_.clear();
- }
-
- /* \brief Read next packet and reset unpacker
- *
- * Read the next packet into `read_buffer_`, releasing all arena
- * allocations performed by the unpacker and resetting the unpacker
- * to its initial state.
- *
- * \return A boolean value. If true, this packet should be treated
- * equivalently to a `DiscoAction::kShutdown` event. If false,
- * this packet should be unpacked.
- */
- bool DequeueNextPacket() {
- uint64_t packet_nbytes = 0;
- int read_size = pipe_.Read(&packet_nbytes, sizeof(packet_nbytes));
- if (read_size == 0) {
- // Special case, connection dropped between packets. Treat as a
- // request to shutdown.
- return true;
- }
-
- ICHECK_EQ(read_size, sizeof(packet_nbytes))
- << "Pipe closed without proper shutdown. Please make sure to
explicitly call "
- "`Session::Shutdown`";
- read_buffer_.resize(packet_nbytes);
- read_size = pipe_.Read(read_buffer_.data(), packet_nbytes);
- ICHECK_EQ(read_size, packet_nbytes)
- << "Pipe closed without proper shutdown. Please make sure to
explicitly call "
- "`Session::Shutdown`";
- read_offset_ = 0;
- this->RecycleAll();
- RPCCode code = RPCCode::kReturn;
- this->Read(&code);
- return false;
- }
-
- size_t Read(void* data, size_t size) final {
- std::memcpy(data, read_buffer_.data() + read_offset_, size);
- read_offset_ += size;
- ICHECK_LE(read_offset_, read_buffer_.size());
- return size;
- }
-
- size_t Write(const void* data, size_t size) final {
- size_t cur_size = write_buffer_.size();
- write_buffer_.resize(cur_size + size);
- std::memcpy(write_buffer_.data() + cur_size, data, size);
- return size;
- }
-
- using dmlc::Stream::Read;
- using dmlc::Stream::ReadArray;
- using dmlc::Stream::Write;
- using dmlc::Stream::WriteArray;
- friend struct RPCReference;
- friend struct DiscoProtocol<DiscoPipeMessageQueue>;
-
- // The read/write buffer will only be accessed by the producer thread.
- std::string write_buffer_;
- std::string read_buffer_;
- size_t read_offset_ = 0;
- support::Pipe pipe_;
-};
-
class DiscoProcessChannel final : public DiscoChannel {
public:
DiscoProcessChannel(int64_t controler_to_worker_fd, int64_t
worker_to_controler_fd)
- : controler_to_worker_(controler_to_worker_fd),
- worker_to_controler_(worker_to_controler_fd) {}
+ : controller_to_worker_pipe_(controler_to_worker_fd),
+ worker_to_controller_pipe_(worker_to_controler_fd),
+ controler_to_worker_(&controller_to_worker_pipe_),
+ worker_to_controler_(&worker_to_controller_pipe_) {}
DiscoProcessChannel(DiscoProcessChannel&& other) = delete;
DiscoProcessChannel(const DiscoProcessChannel& other) = delete;
@@ -148,8 +53,10 @@ class DiscoProcessChannel final : public DiscoChannel {
void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); }
TVMArgs RecvReply() { return worker_to_controler_.Recv(); }
- DiscoPipeMessageQueue controler_to_worker_;
- DiscoPipeMessageQueue worker_to_controler_;
+ support::Pipe controller_to_worker_pipe_;
+ support::Pipe worker_to_controller_pipe_;
+ DiscoStreamMessageQueue controler_to_worker_;
+ DiscoStreamMessageQueue worker_to_controler_;
};
class ProcessSessionObj final : public BcastSessionObj {
@@ -226,7 +133,7 @@ class ProcessSessionObj final : public BcastSessionObj {
int type_codes[4];
PackArgs(values, type_codes,
static_cast<int>(DiscoAction::kDebugSetRegister), reg_id,
worker_id, value);
- workers_[worker_id - 1]->Send(TVMArgs(values, type_codes, 4));
+ SendPacked(worker_id, TVMArgs(values, type_codes, 4));
}
TVMRetValue result;
TVMArgs args = this->RecvReplyPacked(worker_id);
@@ -241,6 +148,14 @@ class ProcessSessionObj final : public BcastSessionObj {
}
}
+  /*!
+   * \brief Send a packed sequence to a single worker. Worker 0 is reached through the channel
+   *        of the in-process DiscoWorkerThread; worker i > 0 through workers_[i - 1]. An
+   *        out-of-range id throws from std::vector::at.
+   */
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    if (worker_id == 0) {
+      worker_0_->channel->Send(args);
+    } else {
+      workers_.at(worker_id - 1)->Send(args);
+    }
+  }
+
TVMArgs RecvReplyPacked(int worker_id) final {
if (worker_id == 0) {
return worker_0_->channel->RecvReply();
@@ -248,6 +163,13 @@ class ProcessSessionObj final : public BcastSessionObj {
return this->workers_.at(worker_id - 1)->RecvReply();
}
+  /*!
+   * \brief Return the channel used to communicate with a worker (worker 0 via its thread's
+   *        channel, others via workers_[worker_id - 1]).
+   * NOTE(review): no caller is visible in this change; confirm it is still needed.
+   */
+  DiscoChannel* GetWorkerChannel(int worker_id) {
+    if (worker_id == 0) {
+      return worker_0_->channel.get();
+    }
+    return workers_.at(worker_id - 1).get();
+  }
+
PackedFunc process_pool_;
std::unique_ptr<DiscoWorkerThread> worker_0_;
std::vector<std::unique_ptr<DiscoProcessChannel>> workers_;
diff --git a/src/runtime/disco/threaded_session.cc
b/src/runtime/disco/threaded_session.cc
index cc9a311a6b..bf6b6107e1 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -173,6 +173,10 @@ class ThreadedSessionObj final : public BcastSessionObj {
}
}
+  /*! \brief Send a packed message over the channel owned by worker `worker_id`. */
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    // at() rejects out-of-range worker ids instead of reading past the end.
+    auto& worker = this->workers_.at(worker_id);
+    worker.channel->Send(args);
+  }
+
TVMArgs RecvReplyPacked(int worker_id) final {
return this->workers_.at(worker_id).channel->RecvReply();
}
diff --git a/src/support/socket.h b/src/support/socket.h
index ac13cd3f2d..032cf257c0 100644
--- a/src/support/socket.h
+++ b/src/support/socket.h
@@ -370,7 +370,7 @@ class Socket {
/*!
* \brief a wrapper of TCP socket that hopefully be cross platform
*/
-class TCPSocket : public Socket {
+class TCPSocket : public Socket, public dmlc::Stream {
public:
TCPSocket() : Socket(INVALID_SOCKET) {}
/*!
@@ -552,6 +552,10 @@ class TCPSocket : public Socket {
ICHECK_EQ(RecvAll(&data[0], datalen), datalen);
return data;
}
+
+  /*!
+   * \brief dmlc::Stream read hook: receive up to `size` bytes into `data`.
+   * \return Number of bytes actually received (may be fewer than `size`).
+   */
+  size_t Read(void* data, size_t size) final {
+    // Short reads are acceptable for a stream Read. Recv wraps recv(), which
+    // reports failure as -1; converting that straight to size_t would look
+    // like an enormous successful read, so reject it explicitly.
+    ssize_t nread = Recv(data, size);
+    ICHECK_GE(nread, 0) << "TCPSocket::Read: recv failed";
+    return static_cast<size_t>(nread);
+  }
+
+  /*!
+   * \brief dmlc::Stream write hook: send all `size` bytes of `data`.
+   * \return Number of bytes sent.
+   */
+  size_t Write(const void* data, size_t size) final {
+    // A single Send may transmit only part of the buffer. Use SendAll, the
+    // all-bytes counterpart of Send (mirroring RecvAll used by RecvBytes), so
+    // callers that do not loop on a short write cannot silently drop data.
+    return static_cast<size_t>(SendAll(data, size));
+  }
};
/*! \brief helper data structure to perform poll */
diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py
index 837b3a14f2..38aa757bf8 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -20,6 +20,9 @@ import tempfile
import numpy as np
import pytest
+import subprocess
+import threading
+import sys
import tvm
import tvm.testing
@@ -29,7 +32,7 @@ from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
-from tvm.exec import disco_worker as _
+from tvm.exec import disco_worker as _ # pylint: disable=unused-import
def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
@@ -46,7 +49,75 @@ def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype):
return host_array.numpy()
-_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
+# The SocketSessionTester created by the most recent create_socket_session()
+# call. Holding it here keeps that session alive until the next call replaces it.
+_SOCKET_SESSION_TESTER = None
+
+
+def get_free_port():
+    """Return a TCP port number that was free at the time of the call.
+
+    The port is found by binding a throwaway socket to port 0 and reading back
+    the port the OS assigned. The socket is closed before returning, so another
+    process may still claim the port before the caller binds to it.
+    """
+    import socket
+
+    # The context manager closes the socket even if bind() raises.
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+class SocketSessionTester:
+    """Test helper that builds a two-node ``di.SocketSession`` on localhost.
+
+    Node 0 is the server. ``di.SocketSession`` is constructed on a helper thread
+    in this process. Every other node is a subprocess running
+    ``-m tvm.exec.disco_remote_socket_session`` that connects to the server's
+    host and port. The constructor presumably blocks until all remote nodes
+    have connected (TODO confirm), which is why it runs on a separate thread
+    while the subprocesses are launched.
+    """
+
+    def __init__(self, num_workers):
+        num_nodes = 2
+        num_groups = 1
+        assert num_workers % num_nodes == 0
+        num_workers_per_node = num_workers // num_nodes
+        server_host = "localhost"
+        server_port = get_free_port()
+        self.sess = None
+        self.remote_nodes = []
+        server_errors = []
+
+        def start_server():
+            # join() does not propagate exceptions raised on a thread, so
+            # record them and re-raise the real cause from __init__.
+            try:
+                self.sess = di.SocketSession(
+                    num_nodes, num_workers_per_node, num_groups, server_host, server_port
+                )
+            except Exception as err:  # pylint: disable=broad-except
+                server_errors.append(err)
+
+        thread = threading.Thread(target=start_server)
+        thread.start()
+
+        cmd = "tvm.exec.disco_remote_socket_session"
+        for _ in range(num_nodes - 1):
+            self.remote_nodes.append(
+                subprocess.Popen(
+                    [
+                        # Use the interpreter running the tests so remote nodes
+                        # import the same tvm installation.
+                        sys.executable,
+                        "-m",
+                        cmd,
+                        server_host,
+                        str(server_port),
+                        str(num_workers_per_node),
+                    ],
+                    stdout=sys.stdout,
+                    stderr=sys.stderr,
+                )
+            )
+
+        thread.join()
+        if server_errors:
+            raise server_errors[0]
+
+    def __del__(self):
+        # Attributes may be missing if __init__ failed early.
+        for node in getattr(self, "remote_nodes", []):
+            node.kill()
+            # Reap the killed child so it does not linger as a zombie.
+            node.wait()
+        if getattr(self, "sess", None) is not None:
+            self.sess.shutdown()
+            del self.sess
+
+
+def create_socket_session(num_workers):
+    """Return a new SocketSession with ``num_workers`` workers for a test.
+
+    The owning SocketSessionTester is stored in a module-level global because
+    its ``__del__`` shuts the session down, so it must outlive this call. The
+    previous tester is released first so that, under CPython reference
+    counting, its session and remote node processes are torn down before new
+    ones start.
+    """
+    global _SOCKET_SESSION_TESTER
+    # Rebind to None instead of deleting the global. If building the new
+    # tester raises, the name must still exist for the next call.
+    _SOCKET_SESSION_TESTER = None
+    _SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
+    assert _SOCKET_SESSION_TESTER.sess is not None
+    return _SOCKET_SESSION_TESTER.sess
+
+
+# Session factories used by the parametrized tests; each is invoked as
+# ``session_kind(num_workers=...)``.
+_all_session_kinds = [di.ThreadedSession, di.ProcessSession, create_socket_session]
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@@ -157,6 +228,11 @@ def test_vm_module(session_kind):
y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape,
dtype=y_np.dtype)
np.testing.assert_equal(y_nd, y_np)
+    # Sync all workers so the temporary files are cleaned up only after every
+    # worker has finished executing.
+ for i in range(num_workers):
+ sess._sync_worker(i)
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
def test_vm_multi_func(session_kind):
@@ -220,10 +296,17 @@ def test_vm_multi_func(session_kind):
np.testing.assert_equal(y_nd, y_np)
np.testing.assert_equal(z_nd, x_np)
+    # Sync all workers so the temporary files are cleaned up only after every
+    # worker has finished executing.
+ for i in range(num_workers):
+ sess._sync_worker(i)
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("num_workers", [1, 2, 4])
def test_num_workers(session_kind, num_workers):
+ if session_kind == create_socket_session and num_workers < 2:
+ return
sess = session_kind(num_workers=num_workers)
assert sess.num_workers == num_workers