This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 98ba395b6f [Disco][Op] gather_to_worker0 (#15690)
98ba395b6f is described below

commit 98ba395b6f7a83a1e69ff17763c9d80aa1eb1264
Author: Lesheng Jin <[email protected]>
AuthorDate: Thu Sep 7 15:45:14 2023 -0700

    [Disco][Op] gather_to_worker0 (#15690)
    
    This PR introduces `gather_to_worker0`, which gathers an array from all other workers to worker-0.
---
 python/tvm/runtime/disco/session.py | 13 ++++++++++++
 src/runtime/disco/builtin.cc        |  5 +++++
 src/runtime/disco/builtin.h         |  7 +++++++
 src/runtime/disco/nccl/nccl.cc      | 41 +++++++++++++++++++++++++++++++++++++
 tests/python/disco/test_nccl.py     | 21 +++++++++++++++++++
 5 files changed, 87 insertions(+)

diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index 6205a767db..f7ee564360 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -291,6 +291,19 @@ class Session(Object):
         func = self._get_cached_method("runtime.disco.scatter_from_worker0")
         func(from_array, to_array)
 
+    def gather_to_worker0(self, from_array: DRef, to_array: DRef) -> None:
+        """Gather an array from all other workers to worker-0.
+
+        Parameters
+        ----------
+        from_array : DRef
+            The array to be gathered from.
+        to_array : DRef
+            The array to be gathered to.
+        """
+        func = self._get_cached_method("runtime.disco.gather_to_worker0")
+        func(from_array, to_array)
+
     def allreduce(
         self,
         array: DRef,
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 5d5e418303..1698f03dd7 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -92,6 +92,10 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
   GetCCLFunc("scatter_from_worker0")(send, recv);
 }
 
+void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
+  GetCCLFunc("gather_to_worker0")(send, recv);
+}
+
 void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); }
 
 int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
@@ -106,6 +110,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
     });
 
TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0);
 
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
 
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple {
   return ShapeTuple({WorkerId()});
diff --git a/src/runtime/disco/builtin.h b/src/runtime/disco/builtin.h
index 784077e6de..a22540f3c1 100644
--- a/src/runtime/disco/builtin.h
+++ b/src/runtime/disco/builtin.h
@@ -65,6 +65,13 @@ NDArray BroadcastFromWorker0(NDArray buffer);
  * \param recv The receiving buffer, which must not be None.
  */
 void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
+/*!
+ * \brief Perform a gather operation to worker-0.
+ * \param send The sending buffer, which must not be None.
+ * \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The
+ * receiving buffer will be divided into equal parts and receive from each worker accordingly.
+ */
+void GatherToWorker0(NDArray send, Optional<NDArray> recv);
 /*!
  * \brief Receive a buffer from worker-0. No-op if the current worker is worker-0.
  * \param buffer The buffer to be received
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 52ea8798a0..e5ab3296b8 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -158,6 +158,46 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
   NCCL_CALL(ncclGroupEnd());
 }
 
+void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
+  CHECK(send.defined()) << "ValueError: buffer `send` must not be None";
+  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  int worker_id = ctx->worker->worker_id;
+  int num_workers = ctx->worker->num_workers;
+  if (worker_id == 0) {
+    CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
+    NDArray buffer = recv.value();
+    int64_t numel = buffer.Shape()->Product();
+    CHECK_EQ(numel % num_workers, 0)
+        << "ValueError: Gathering evenly requires that the number of elements in the buffer to be "
+           "divisible by the number of workers, but got numel = "
+        << numel << " and " << num_workers << " workers.";
+    DataType dtype(buffer->dtype);
+    int64_t numel_per_shard = numel / num_workers;
+    int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
+    CHECK_EQ(numel_per_shard, send.Shape()->Product())
+        << "ValueError: The number of elements in buffer `send` must be the same as each shard of "
+           "buffer `recv`. `recv.size` is "
+        << numel << ", but `send.size` is " << send.Shape()->Product() << ".";
+    NCCL_CALL(ncclGroupStart());
+    uint8_t* data = static_cast<uint8_t*>(buffer->data);
+    for (int i = 0; i < num_workers; ++i) {
+      NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->stream));
+      data += bytes_per_shard;
+    }
+  } else {
+    if (recv.defined()) {
+      LOG(WARNING) << "ValueError: buffer `recv` must be None when worker_id != 0. However, got "
+                      "recv = "
+                   << recv.get() << ". This will be ignored.";
+    }
+    NCCL_CALL(ncclGroupStart());
+  }
+  int64_t numel = send.Shape()->Product();
+  DataType dtype(send->dtype);
+  NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->stream));
+  NCCL_CALL(ncclGroupEnd());
+}
+
 void RecvFromWorker0(NDArray buffer) {
   NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
   CHECK_NE(ctx->worker->worker_id, 0)
@@ -183,6 +223,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce").set_body_typed([](NDArray se
 TVM_REGISTER_GLOBAL("runtime.disco.nccl.broadcast_from_worker0")
     .set_body_typed(BroadcastFromWorker0);
 
TVM_REGISTER_GLOBAL("runtime.disco.nccl.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.nccl.gather_to_worker0").set_body_typed(GatherToWorker0);
 
TVM_REGISTER_GLOBAL("runtime.disco.nccl.recv_from_worker0").set_body_typed(RecvFromWorker0);
 
 }  // namespace nccl
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index 8fc56b1ec1..adf507b024 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -97,6 +97,27 @@ def test_scatter():
     )
 
 
+def test_gather():
+    num_workers = 2
+    devices = [1, 2]
+    array = np.arange(36, dtype="float32")
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    sess.init_ccl("nccl", *devices)
+    d_src = sess.empty((3, 3, 2), "float32")
+    d_dst = sess.empty((3, 4, 3), "float32")
+
+    d_src.debug_copy_from(0, array[:18])
+    d_src.debug_copy_from(1, array[18:])
+
+    sess.gather_to_worker0(d_src, d_dst)
+
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(0).numpy(),
+        array.reshape(3, 4, 3),
+    )
+
+
 def test_mlp():  # pylint: disable=too-many-locals
     num_workers = 2
     devices = [1, 2]

Reply via email to