This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ae1be53d6d [Disco] Cross-group and p2p send/receive primitives (#17191)
ae1be53d6d is described below
commit ae1be53d6dc08ad8a95ddf6af022880e836e8704
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Jul 24 08:03:21 2024 -0400
[Disco] Cross-group and p2p send/receive primitives (#17191)
This PR introduces the disco CCL primitives for cross-group
and p2p communication.
Specifically, we introduce the send/receive primitives for one group
to send a buffer to its next group, where every worker in the first
group sends the buffer to the corresponding worker in the second
group. The p2p communication refers to the send/receive operations
to/from a target global worker.
---
include/tvm/runtime/disco/builtin.h | 24 ++++++++++
python/tvm/relax/frontend/nn/core.py | 6 +--
src/runtime/disco/builtin.cc | 16 +++++++
src/runtime/disco/nccl/nccl.cc | 86 ++++++++++++++++++++++++++++++++++++
tests/python/disco/test_ccl.py | 40 ++++++++++++++++-
5 files changed, 168 insertions(+), 4 deletions(-)
diff --git a/include/tvm/runtime/disco/builtin.h
b/include/tvm/runtime/disco/builtin.h
index 7d15e35fbd..4453d9737f 100644
--- a/include/tvm/runtime/disco/builtin.h
+++ b/include/tvm/runtime/disco/builtin.h
@@ -114,6 +114,30 @@ TVM_DLL void GatherToWorker0(NDArray send, bool in_group,
Optional<NDArray> recv
* \param buffer The buffer to be received
*/
TVM_DLL void RecvFromWorker0(NDArray buffer);
+/*!
+ * \brief Send a buffer to the corresponding worker in the next group.
+ * An error is thrown if the worker is already in the last group.
+ * \param buffer The sending buffer.
+ */
+TVM_DLL void SendToNextGroup(NDArray buffer);
+/*!
+ * \brief Receive a buffer from the corresponding worker in the previous group.
+ * An error is thrown if the worker is already in the first group.
+ * \param buffer The receiving buffer.
+ */
+TVM_DLL void RecvFromPrevGroup(NDArray buffer);
+/*!
+ * \brief Send a buffer to the target receiver worker (globally across all
groups).
+ * \param buffer The sending buffer.
+ * \param receiver_id The global receiver worker id.
+ */
+TVM_DLL void SendToWorker(NDArray buffer, int receiver_id);
+/*!
+ * \brief Receive a buffer from the target sender worker (globally across all
groups).
+ * \param buffer The receiving buffer.
+ * \param sender_id The global sender worker id.
+ */
+TVM_DLL void RecvFromWorker(NDArray buffer, int sender_id);
/*! \brief Get the local worker id */
TVM_DLL int WorkerId();
/*!
diff --git a/python/tvm/relax/frontend/nn/core.py
b/python/tvm/relax/frontend/nn/core.py
index 46e016a242..3511c38a2b 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -549,16 +549,16 @@ class ModuleList(Module):
def __iter__(self):
return iter(self.modules)
- def __getitem__(self, idx):
+ def __getitem__(self, idx: int) -> Module:
return self.modules[idx]
- def __setitem__(self, idx, module):
+ def __setitem__(self, idx: int, module: Module) -> None:
self.modules[idx] = module
def __len__(self):
return len(self.modules)
- def append(self, module):
+ def append(self, module: Module):
"""Add a module to the end of the ModuleList"""
self.modules.append(module)
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 0cb2ee6f5d..760a330a7a 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -101,6 +101,18 @@ void GatherToWorker0(NDArray send, bool in_group,
Optional<NDArray> recv) {
void RecvFromWorker0(NDArray buffer) {
GetCCLFunc("recv_from_worker0")(buffer); }
+void SendToNextGroup(NDArray buffer) {
GetCCLFunc("send_to_next_group")(buffer); }
+
+void RecvFromPrevGroup(NDArray buffer) {
GetCCLFunc("recv_from_prev_group")(buffer); }
+
+void SendToWorker(NDArray buffer, int receiver_id) {
+ GetCCLFunc("send_to_worker")(buffer, receiver_id);
+}
+
+void RecvFromWorker(NDArray buffer, int sender_id) {
+ GetCCLFunc("recv_from_worker")(buffer, sender_id);
+}
+
int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
void SyncWorker() {
@@ -136,6 +148,10 @@
TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(Broad
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup);
+TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup);
+TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker);
+TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker);
TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() ->
ShapeTuple {
return ShapeTuple({WorkerId()});
});
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 2d2c528b52..35e8fd06b3 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -254,6 +254,57 @@ void RecvFromWorker0(NDArray buffer) {
NCCL_CALL(ncclGroupEnd());
}
+void SendToNextGroup(NDArray buffer) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ deviceStream_t stream = ctx->GetDefaultStream();
+ int worker_id = ctx->worker->worker_id;
+ int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+ int receiver_id = worker_id + group_size;
+ CHECK_LT(receiver_id, ctx->worker->num_workers)
+ << "The current group is already the last group and there is no such a
next group.";
+ NCCL_CALL(ncclGroupStart());
+ NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(),
AsNCCLDataType(buffer.DataType()),
+ receiver_id, ctx->global_comm, stream));
+ NCCL_CALL(ncclGroupEnd());
+}
+
+void RecvFromPrevGroup(NDArray buffer) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ deviceStream_t stream = ctx->GetDefaultStream();
+ int worker_id = ctx->worker->worker_id;
+ int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+ int sender_id = worker_id - group_size;
+ CHECK_GE(sender_id, 0)
+ << "The current group is already the first group and there is no such a
previous group.";
+ NCCL_CALL(ncclGroupStart());
+ NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(),
AsNCCLDataType(buffer.DataType()),
+ sender_id, ctx->global_comm, stream));
+ NCCL_CALL(ncclGroupEnd());
+}
+
+void SendToWorker(NDArray buffer, int receiver_id) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ deviceStream_t stream = ctx->GetDefaultStream();
+ int worker_id = ctx->worker->worker_id;
+ CHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers)
+ << "Invalid receiver id " << receiver_id << ". The world size is "
+ << ctx->worker->num_workers;
+ CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself.";
+ NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(),
AsNCCLDataType(buffer.DataType()),
+ receiver_id, ctx->global_comm, stream));
+}
+
+void RecvFromWorker(NDArray buffer, int sender_id) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ deviceStream_t stream = ctx->GetDefaultStream();
+ int worker_id = ctx->worker->worker_id;
+ CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers)
+ << "Invalid sender id " << sender_id << ". The world size is " <<
ctx->worker->num_workers;
+ CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself.";
+ NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(),
AsNCCLDataType(buffer.DataType()),
+ sender_id, ctx->global_comm, stream));
+}
+
void SyncWorker() {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ICHECK(ctx->worker != nullptr);
@@ -284,8 +335,43 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
".gather_to_worker0")
.set_body_typed(GatherToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
.set_body_typed(RecvFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group")
+ .set_body_typed(SendToNextGroup);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
".recv_from_prev_group")
+ .set_body_typed(RecvFromPrevGroup);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker")
+ .set_body_typed(SendToWorker);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker")
+ .set_body_typed(RecvFromWorker);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
".sync_worker").set_body_typed(SyncWorker);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
+ ".test_send_to_next_group_recv_from_prev_group")
+ .set_body_typed([](NDArray buffer) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world
size to be 4.";
+ CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group
size to be 2.";
+ int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+ int group_id = ctx->worker->worker_id / group_size;
+ if (group_id == 0) {
+ tvm::runtime::nccl::SendToNextGroup(buffer);
+ } else {
+ tvm::runtime::nccl::RecvFromPrevGroup(buffer);
+ }
+ });
+
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
".test_worker2_sends_to_worker0")
+ .set_body_typed([](NDArray buffer) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world
size to be 4.";
+ CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group
size to be 2.";
+ if (ctx->worker->worker_id == 2) {
+ tvm::runtime::nccl::SendToWorker(buffer, 0);
+ } else if (ctx->worker->worker_id == 0) {
+ tvm::runtime::nccl::RecvFromWorker(buffer, 2);
+ }
+ });
+
} // namespace nccl
} // namespace runtime
} // namespace tvm
diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py
index 6c63f64554..c29ece9572 100644
--- a/tests/python/disco/test_ccl.py
+++ b/tests/python/disco/test_ccl.py
@@ -25,11 +25,11 @@ import pytest
import tvm
import tvm.testing
from tvm import dlight as dl
+from tvm import get_global_func
from tvm import relax as rx
from tvm.runtime import disco as di
from tvm.runtime.relax_vm import VirtualMachine
from tvm.script import relax as R
-from tvm import get_global_func
_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
_ccl = [get_global_func("runtime.disco.compiled_ccl")()]
@@ -391,6 +391,44 @@ def test_group_gather(session_kind, ccl, capfd):
), "No warning messages should be generated from
disco.Session.gather_to_worker0"
[email protected]("session_kind", _all_session_kinds)
[email protected]("ccl", _ccl)
+def test_send_to_next_group_receive_from_prev_group(session_kind, ccl):
+ devices = [0, 1, 2, 3]
+ sess = session_kind(num_workers=len(devices), num_groups=2)
+ sess.init_ccl(ccl, *devices)
+
+ array_1 = np.arange(12, dtype="float32").reshape(3, 4)
+ array_2 = np.arange(start=1, stop=-11, step=-1,
dtype="float32").reshape(3, 4)
+ d_array = sess.empty((3, 4), "float32")
+ d_array.debug_copy_from(0, array_1)
+ d_array.debug_copy_from(1, array_2)
+ sess.get_global_func("runtime.disco." + ccl +
".test_send_to_next_group_recv_from_prev_group")(
+ d_array
+ )
+
+ result_1 = d_array.debug_get_from_remote(2).numpy()
+ result_2 = d_array.debug_get_from_remote(3).numpy()
+ np.testing.assert_equal(result_1, array_1)
+ np.testing.assert_equal(result_2, array_2)
+
+
[email protected]("session_kind", _all_session_kinds)
[email protected]("ccl", _ccl)
+def test_worker2_send_to_worker0(session_kind, ccl):
+ devices = [0, 1, 2, 3]
+ sess = session_kind(num_workers=len(devices), num_groups=2)
+ sess.init_ccl(ccl, *devices)
+
+ array = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3,
4)
+ d_array = sess.empty((3, 4), "float32")
+ d_array.debug_copy_from(2, array)
+ sess.get_global_func("runtime.disco." + ccl +
".test_worker2_sends_to_worker0")(d_array)
+
+ result = d_array.debug_get_from_remote(0).numpy()
+ np.testing.assert_equal(result, array)
+
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals