This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 98ba395b6f [Disco][Op] gather_to_worker0 (#15690)
98ba395b6f is described below
commit 98ba395b6f7a83a1e69ff17763c9d80aa1eb1264
Author: Lesheng Jin <[email protected]>
AuthorDate: Thu Sep 7 15:45:14 2023 -0700
[Disco][Op] gather_to_worker0 (#15690)
This PR introduces `gather_to_worker0`, which gathers an array from all
workers (including worker-0 itself) to worker-0.
---
python/tvm/runtime/disco/session.py | 13 ++++++++++++
src/runtime/disco/builtin.cc | 5 +++++
src/runtime/disco/builtin.h | 7 +++++++
src/runtime/disco/nccl/nccl.cc | 41 +++++++++++++++++++++++++++++++++++++
tests/python/disco/test_nccl.py | 21 +++++++++++++++++++
5 files changed, 87 insertions(+)
diff --git a/python/tvm/runtime/disco/session.py
b/python/tvm/runtime/disco/session.py
index 6205a767db..f7ee564360 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -291,6 +291,19 @@ class Session(Object):
func = self._get_cached_method("runtime.disco.scatter_from_worker0")
func(from_array, to_array)
+ def gather_to_worker0(self, from_array: DRef, to_array: DRef) -> None:
+ """Gather an array from all other workers to worker-0.
+
+ Parameters
+ ----------
+ from_array : DRef
+ The array to be gathered from.
+ to_array : DRef
+ The array to be gathered to.
+ """
+ func = self._get_cached_method("runtime.disco.gather_to_worker0")
+ func(from_array, to_array)
+
def allreduce(
self,
array: DRef,
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 5d5e418303..1698f03dd7 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -92,6 +92,10 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray
recv) {
GetCCLFunc("scatter_from_worker0")(send, recv);
}
+/*! \brief Forward the gather to the CCL-specific `gather_to_worker0` found via GetCCLFunc. */
+void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
+  const auto& gather_fn = GetCCLFunc("gather_to_worker0");
+  gather_fn(send, recv);
+}
+
void RecvFromWorker0(NDArray buffer) {
GetCCLFunc("recv_from_worker0")(buffer); }
int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
@@ -106,6 +110,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
});
TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() ->
ShapeTuple {
return ShapeTuple({WorkerId()});
diff --git a/src/runtime/disco/builtin.h b/src/runtime/disco/builtin.h
index 784077e6de..a22540f3c1 100644
--- a/src/runtime/disco/builtin.h
+++ b/src/runtime/disco/builtin.h
@@ -65,6 +65,13 @@ NDArray BroadcastFromWorker0(NDArray buffer);
* \param recv The receiving buffer, which must not be None.
*/
void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
+/*!
+ * \brief Perform a gather operation to worker-0: every worker, worker-0 included, contributes
+ * its `send` buffer.
+ * \param send The sending buffer, which must not be None.
+ * \param recv The receiving buffer. It must be provided on worker-0 and must be None on other
+ * workers. It is divided into equal parts, and the `i`-th part receives the `send` buffer of
+ * worker `i`.
+ */
+void GatherToWorker0(NDArray send, Optional<NDArray> recv);
/*!
* \brief Receive a buffer from worker-0. No-op if the current worker is
worker-0.
* \param buffer The buffer to be received
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 52ea8798a0..e5ab3296b8 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -158,6 +158,46 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray
recv) {
NCCL_CALL(ncclGroupEnd());
}
+/*!
+ * \brief Gather `send` from every worker, worker-0 included, into `recv` on worker-0.
+ * \param send The buffer this worker contributes, which must not be None.
+ * \param recv Required on worker-0, where it is split into `num_workers` equal shards and shard
+ * `i` receives the `send` buffer of worker `i`. Ignored (with a warning) on other workers.
+ */
+void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
+  CHECK(send.defined()) << "ValueError: buffer `send` must not be None";
+  NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+  int worker_id = ctx->worker->worker_id;
+  int num_workers = ctx->worker->num_workers;
+  if (worker_id == 0) {
+    CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
+    NDArray buffer = recv.value();
+    int64_t numel = buffer.Shape()->Product();
+    CHECK_EQ(numel % num_workers, 0)
+        << "ValueError: Gathering evenly requires that the number of elements in the buffer to be "
+           "divisible by the number of workers, but got numel = "
+        << numel << " and " << num_workers << " workers.";
+    DataType dtype(buffer->dtype);
+    // The ncclRecv calls below are typed by `recv`, while every ncclSend is typed by its `send`.
+    // A mismatch would pair a send and a recv of different element types (and possibly sizes).
+    DataType send_dtype(send->dtype);
+    CHECK(dtype == send_dtype)
+        << "ValueError: buffer `send` and `recv` must have the same dtype, but got send.dtype = "
+        << send_dtype << " and recv.dtype = " << dtype << ".";
+    int64_t numel_per_shard = numel / num_workers;
+    int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
+    CHECK_EQ(numel_per_shard, send.Shape()->Product())
+        << "ValueError: The number of elements in buffer `send` must be the same as each shard of "
+           "buffer `recv`. `recv.size` is "
+        << numel << ", but `send.size` is " << send.Shape()->Product() << ".";
+    NCCL_CALL(ncclGroupStart());
+    uint8_t* data = static_cast<uint8_t*>(buffer->data);
+    // Shard `i` is filled by worker `i`; for i == 0 this pairs with worker-0's own send below.
+    for (int i = 0; i < num_workers; ++i) {
+      NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->stream));
+      data += bytes_per_shard;
+    }
+  } else {
+    if (recv.defined()) {
+      LOG(WARNING) << "ValueError: buffer `recv` must be None when worker_id != 0. However, got "
+                      "recv = "
+                   << recv.get() << ". This will be ignored.";
+    }
+    NCCL_CALL(ncclGroupStart());
+  }
+  // Every worker, worker-0 included, sends its whole `send` buffer to worker-0 in the same group.
+  int64_t numel = send.Shape()->Product();
+  DataType dtype(send->dtype);
+  NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->stream));
+  NCCL_CALL(ncclGroupEnd());
+}
+
void RecvFromWorker0(NDArray buffer) {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
CHECK_NE(ctx->worker->worker_id, 0)
@@ -183,6 +223,7 @@
TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce").set_body_typed([](NDArray se
TVM_REGISTER_GLOBAL("runtime.disco.nccl.broadcast_from_worker0")
.set_body_typed(BroadcastFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.nccl.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.nccl.gather_to_worker0").set_body_typed(GatherToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.nccl.recv_from_worker0").set_body_typed(RecvFromWorker0);
} // namespace nccl
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index 8fc56b1ec1..adf507b024 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -97,6 +97,27 @@ def test_scatter():
)
+def test_gather():
+ num_workers = 2
+ devices = [1, 2]
+ array = np.arange(36, dtype="float32")
+
+ sess = di.ThreadedSession(num_workers=num_workers)
+ sess.init_ccl("nccl", *devices)
+ d_src = sess.empty((3, 3, 2), "float32")
+ d_dst = sess.empty((3, 4, 3), "float32")
+
+ d_src.debug_copy_from(0, array[:18])
+ d_src.debug_copy_from(1, array[18:])
+
+ sess.gather_to_worker0(d_src, d_dst)
+
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(0).numpy(),
+ array.reshape(3, 4, 3),
+ )
+
+
def test_mlp(): # pylint: disable=too-many-locals
num_workers = 2
devices = [1, 2]