This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ae1be53d6d [Disco] Cross-group and p2p send/receive primitives (#17191)
ae1be53d6d is described below

commit ae1be53d6dc08ad8a95ddf6af022880e836e8704
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Jul 24 08:03:21 2024 -0400

    [Disco] Cross-group and p2p send/receive primitives (#17191)
    
    This PR introduces the disco CCL primitives for cross-group
    and p2p communication.
    
    Specifically, we introduce the send/receive primitives for one group
    to send a buffer to its next group, where every worker in the first
    group sends the buffer to the corresponding worker in the second
    group. The p2p communication refers to the send/receive operations
    to/from a target global worker.
---
 include/tvm/runtime/disco/builtin.h  | 24 ++++++++++
 python/tvm/relax/frontend/nn/core.py |  6 +--
 src/runtime/disco/builtin.cc         | 16 +++++++
 src/runtime/disco/nccl/nccl.cc       | 86 ++++++++++++++++++++++++++++++++++++
 tests/python/disco/test_ccl.py       | 40 ++++++++++++++++-
 5 files changed, 168 insertions(+), 4 deletions(-)

diff --git a/include/tvm/runtime/disco/builtin.h 
b/include/tvm/runtime/disco/builtin.h
index 7d15e35fbd..4453d9737f 100644
--- a/include/tvm/runtime/disco/builtin.h
+++ b/include/tvm/runtime/disco/builtin.h
@@ -114,6 +114,30 @@ TVM_DLL void GatherToWorker0(NDArray send, bool in_group, 
Optional<NDArray> recv
  * \param buffer The buffer to be received
  */
 TVM_DLL void RecvFromWorker0(NDArray buffer);
+/*!
+ * \brief Send a buffer to the corresponding worker in the next group.
+ * An error is thrown if the worker is already in the last group.
+ * \param buffer The sending buffer.
+ */
+TVM_DLL void SendToNextGroup(NDArray buffer);
+/*!
+ * \brief Receive a buffer from the corresponding worker in the previous group.
+ * An error is thrown if the worker is already in the first group.
+ * \param buffer The receiving buffer.
+ */
+TVM_DLL void RecvFromPrevGroup(NDArray buffer);
+/*!
+ * \brief Send a buffer to the target receiver worker (globally across all 
groups).
+ * \param buffer The sending buffer.
+ * \param receiver_id The global receiver worker id.
+ */
+TVM_DLL void SendToWorker(NDArray buffer, int receiver_id);
+/*!
+ * \brief Receive a buffer from the target sender worker (globally across all 
groups).
+ * \param buffer The receiving buffer.
+ * \param sender_id The global sender worker id.
+ */
+TVM_DLL void RecvFromWorker(NDArray buffer, int sender_id);
 /*! \brief Get the local worker id */
 TVM_DLL int WorkerId();
 /*!
diff --git a/python/tvm/relax/frontend/nn/core.py 
b/python/tvm/relax/frontend/nn/core.py
index 46e016a242..3511c38a2b 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -549,16 +549,16 @@ class ModuleList(Module):
     def __iter__(self):
         return iter(self.modules)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Module:
         return self.modules[idx]
 
-    def __setitem__(self, idx, module):
+    def __setitem__(self, idx: int, module: Module) -> None:
         self.modules[idx] = module
 
     def __len__(self):
         return len(self.modules)
 
-    def append(self, module):
+    def append(self, module: Module):
         """Add a module to the end of the ModuleList"""
         self.modules.append(module)
 
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 0cb2ee6f5d..760a330a7a 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -101,6 +101,18 @@ void GatherToWorker0(NDArray send, bool in_group, 
Optional<NDArray> recv) {
 
 void RecvFromWorker0(NDArray buffer) { 
GetCCLFunc("recv_from_worker0")(buffer); }
 
+void SendToNextGroup(NDArray buffer) { 
GetCCLFunc("send_to_next_group")(buffer); }
+
+void RecvFromPrevGroup(NDArray buffer) { 
GetCCLFunc("recv_from_prev_group")(buffer); }
+
+void SendToWorker(NDArray buffer, int receiver_id) {
+  GetCCLFunc("send_to_worker")(buffer, receiver_id);
+}
+
+void RecvFromWorker(NDArray buffer, int sender_id) {
+  GetCCLFunc("recv_from_worker")(buffer, sender_id);
+}
+
 int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
 
 void SyncWorker() {
@@ -136,6 +148,10 @@ 
TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(Broad
 
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
 
TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
 
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup);
+TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup);
+TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker);
+TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker);
 TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> 
ShapeTuple {
   return ShapeTuple({WorkerId()});
 });
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 2d2c528b52..35e8fd06b3 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -254,6 +254,57 @@ void RecvFromWorker0(NDArray buffer) {
   NCCL_CALL(ncclGroupEnd());
 }
 
+void SendToNextGroup(NDArray buffer) {
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+  deviceStream_t stream = ctx->GetDefaultStream();
+  int worker_id = ctx->worker->worker_id;
+  int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+  int receiver_id = worker_id + group_size;
+  CHECK_LT(receiver_id, ctx->worker->num_workers)
+      << "The current group is already the last group and there is no such a 
next group.";
+  NCCL_CALL(ncclGroupStart());
+  NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), 
AsNCCLDataType(buffer.DataType()),
+                     receiver_id, ctx->global_comm, stream));
+  NCCL_CALL(ncclGroupEnd());
+}
+
+void RecvFromPrevGroup(NDArray buffer) {
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+  deviceStream_t stream = ctx->GetDefaultStream();
+  int worker_id = ctx->worker->worker_id;
+  int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+  int sender_id = worker_id - group_size;
+  CHECK_GE(sender_id, 0)
+      << "The current group is already the first group and there is no such a 
previous group.";
+  NCCL_CALL(ncclGroupStart());
+  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), 
AsNCCLDataType(buffer.DataType()),
+                     sender_id, ctx->global_comm, stream));
+  NCCL_CALL(ncclGroupEnd());
+}
+
+void SendToWorker(NDArray buffer, int receiver_id) {
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+  deviceStream_t stream = ctx->GetDefaultStream();
+  int worker_id = ctx->worker->worker_id;
+  CHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers)
+      << "Invalid receiver id " << receiver_id << ". The world size is "
+      << ctx->worker->num_workers;
+  CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself.";
+  NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), 
AsNCCLDataType(buffer.DataType()),
+                     receiver_id, ctx->global_comm, stream));
+}
+
+void RecvFromWorker(NDArray buffer, int sender_id) {
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+  deviceStream_t stream = ctx->GetDefaultStream();
+  int worker_id = ctx->worker->worker_id;
+  CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers)
+      << "Invalid sender id " << sender_id << ". The world size is " << 
ctx->worker->num_workers;
+  CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself.";
+  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), 
AsNCCLDataType(buffer.DataType()),
+                     sender_id, ctx->global_comm, stream));
+}
+
 void SyncWorker() {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ICHECK(ctx->worker != nullptr);
@@ -284,8 +335,43 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME 
".gather_to_worker0")
     .set_body_typed(GatherToWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
     .set_body_typed(RecvFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group")
+    .set_body_typed(SendToNextGroup);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME 
".recv_from_prev_group")
+    .set_body_typed(RecvFromPrevGroup);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker")
+    .set_body_typed(SendToWorker);
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker")
+    .set_body_typed(RecvFromWorker);
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME 
".sync_worker").set_body_typed(SyncWorker);
 
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
+                    ".test_send_to_next_group_recv_from_prev_group")
+    .set_body_typed([](NDArray buffer) {
+      CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+      CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world 
size to be 4.";
+      CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group 
size to be 2.";
+      int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+      int group_id = ctx->worker->worker_id / group_size;
+      if (group_id == 0) {
+        tvm::runtime::nccl::SendToNextGroup(buffer);
+      } else {
+        tvm::runtime::nccl::RecvFromPrevGroup(buffer);
+      }
+    });
+
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME 
".test_worker2_sends_to_worker0")
+    .set_body_typed([](NDArray buffer) {
+      CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+      CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world 
size to be 4.";
+      CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group 
size to be 2.";
+      if (ctx->worker->worker_id == 2) {
+        tvm::runtime::nccl::SendToWorker(buffer, 0);
+      } else if (ctx->worker->worker_id == 0) {
+        tvm::runtime::nccl::RecvFromWorker(buffer, 2);
+      }
+    });
+
 }  // namespace nccl
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py
index 6c63f64554..c29ece9572 100644
--- a/tests/python/disco/test_ccl.py
+++ b/tests/python/disco/test_ccl.py
@@ -25,11 +25,11 @@ import pytest
 import tvm
 import tvm.testing
 from tvm import dlight as dl
+from tvm import get_global_func
 from tvm import relax as rx
 from tvm.runtime import disco as di
 from tvm.runtime.relax_vm import VirtualMachine
 from tvm.script import relax as R
-from tvm import get_global_func
 
 _all_session_kinds = [di.ThreadedSession, di.ProcessSession]
 _ccl = [get_global_func("runtime.disco.compiled_ccl")()]
@@ -391,6 +391,44 @@ def test_group_gather(session_kind, ccl, capfd):
     ), "No warning messages should be generated from 
disco.Session.gather_to_worker0"
 
 
[email protected]("session_kind", _all_session_kinds)
[email protected]("ccl", _ccl)
+def test_send_to_next_group_receive_from_prev_group(session_kind, ccl):
+    devices = [0, 1, 2, 3]
+    sess = session_kind(num_workers=len(devices), num_groups=2)
+    sess.init_ccl(ccl, *devices)
+
+    array_1 = np.arange(12, dtype="float32").reshape(3, 4)
+    array_2 = np.arange(start=1, stop=-11, step=-1, 
dtype="float32").reshape(3, 4)
+    d_array = sess.empty((3, 4), "float32")
+    d_array.debug_copy_from(0, array_1)
+    d_array.debug_copy_from(1, array_2)
+    sess.get_global_func("runtime.disco." + ccl + 
".test_send_to_next_group_recv_from_prev_group")(
+        d_array
+    )
+
+    result_1 = d_array.debug_get_from_remote(2).numpy()
+    result_2 = d_array.debug_get_from_remote(3).numpy()
+    np.testing.assert_equal(result_1, array_1)
+    np.testing.assert_equal(result_2, array_2)
+
+
[email protected]("session_kind", _all_session_kinds)
[email protected]("ccl", _ccl)
+def test_worker2_send_to_worker0(session_kind, ccl):
+    devices = [0, 1, 2, 3]
+    sess = session_kind(num_workers=len(devices), num_groups=2)
+    sess.init_ccl(ccl, *devices)
+
+    array = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 
4)
+    d_array = sess.empty((3, 4), "float32")
+    d_array.debug_copy_from(2, array)
+    sess.get_global_func("runtime.disco." + ccl + 
".test_worker2_sends_to_worker0")(d_array)
+
+    result = d_array.debug_get_from_remote(0).numpy()
+    np.testing.assert_equal(result, array)
+
+
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 @pytest.mark.parametrize("ccl", _ccl)
 def test_mlp(session_kind, ccl):  # pylint: disable=too-many-locals

Reply via email to