This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bbc97c77fb [Disco] Group-wise operation (#17180)
bbc97c77fb is described below
commit bbc97c77fbd890361a8705c4450057c5c1bfd0db
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Jul 23 05:52:57 2024 -0700
[Disco] Group-wise operation (#17180)
This PR introduces the group attribute into Disco, so that group-wise
allreduce and allgather are enabled.
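As a usage illustration (not part of the diff below), here is a minimal sketch of the
new group-wise API, adapted from the tests added in tests/python/disco/test_ccl.py;
the backend, device ids, shapes, and values are illustrative only, and num_workers
must be divisible by num_groups:

    import numpy as np
    from tvm.runtime import disco as di

    # Four workers split into two groups: workers {0, 1} and {2, 3}.
    sess = di.ProcessSession(num_workers=4, num_groups=2)
    sess.init_ccl("nccl", 0, 1, 2, 3)

    src = sess.empty((3, 4), "float32")
    dst = sess.empty((3, 4), "float32")
    for worker_id in range(4):
        # Give each worker a distinct constant so the per-group sums differ.
        src.debug_copy_from(worker_id, np.full((3, 4), worker_id, dtype="float32"))

    # With in_group=True the reduction stays inside each group:
    # workers 0 and 1 end up with 0 + 1, workers 2 and 3 with 2 + 3.
    sess.allreduce(src, dst, op="sum", in_group=True)

Passing in_group=False instead reduces across all four workers.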
---
include/tvm/relax/attrs/ccl.h | 18 +++
include/tvm/runtime/disco/builtin.h | 15 +-
include/tvm/runtime/disco/disco_worker.h | 8 +-
include/tvm/runtime/disco/session.h | 8 +-
python/tvm/exec/disco_worker.py | 15 +-
python/tvm/relax/frontend/nn/op.py | 13 +-
python/tvm/relax/op/ccl/ccl.py | 24 +--
python/tvm/relax/transform/legalize_ops/ccl.py | 10 +-
python/tvm/runtime/disco/process_pool.py | 10 +-
python/tvm/runtime/disco/session.py | 101 +++++++++----
src/relax/op/ccl/ccl.cc | 22 ++-
src/relax/op/ccl/ccl.h | 4 +-
src/runtime/disco/builtin.cc | 34 +++--
src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc | 4 +-
src/runtime/disco/cuda_ipc/custom_allreduce.cc | 4 +-
src/runtime/disco/disco_worker_thread.h | 4 +-
src/runtime/disco/loader.cc | 8 +-
src/runtime/disco/nccl/nccl.cc | 102 ++++++++-----
src/runtime/disco/nccl/nccl_context.h | 13 +-
src/runtime/disco/process_session.cc | 21 ++-
src/runtime/disco/threaded_session.cc | 16 +-
tests/python/disco/test_callback.py | 11 +-
tests/python/disco/test_ccl.py | 168 ++++++++++++++++++++-
tests/python/disco/test_loader.py | 3 +-
tests/python/disco/test_session.py | 20 +--
...ributed_transform_lower_global_to_local_view.py | 4 +-
.../relax/test_transform_legalize_ops_ccl.py | 18 +--
27 files changed, 491 insertions(+), 187 deletions(-)
diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index 42cec88de6..de043f92be 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -32,14 +32,32 @@ namespace relax {
/*! \brief Attributes used in allreduce operators */
struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
String op_type;
+ bool in_group;
TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
TVM_ATTR_FIELD(op_type).describe(
"The type of reduction operation to be applied to the input data. Now only sum is "
"supported.");
+ TVM_ATTR_FIELD(in_group).describe(
+ "Whether the reduction operation performs in group or globally or in group as default.");
}
}; // struct AllReduceAttrs
+/*! \brief Attributes used in allgather operators */
+struct AllGatherAttrs : public tvm::AttrsNode<AllGatherAttrs> {
+ int num_workers;
+ bool in_group;
+
+ TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") {
+ TVM_ATTR_FIELD(num_workers)
+ .describe(
+ "The number of workers, also the number of parts the given buffer should be chunked "
+ "into.");
+ TVM_ATTR_FIELD(in_group).describe(
+ "Whether the allgather operation performs in group or globally or in group as default.");
+ }
+}; // struct AllGatherAttrs
+
/*! \brief Attributes used in scatter operators */
struct ScatterCollectiveAttrs : public tvm::AttrsNode<ScatterCollectiveAttrs> {
int num_workers;
diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h
index cf9967dbfe..7d15e35fbd 100644
--- a/include/tvm/runtime/disco/builtin.h
+++ b/include/tvm/runtime/disco/builtin.h
@@ -75,35 +75,40 @@ TVM_DLL NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device devic
* \brief Perform an allreduce operation using the underlying communication library
* \param send The array send to perform allreduce on
* \param reduce_kind The kind of reduction operation (e.g. sum, avg, min, max)
+ * \param in_group Whether the allreduce operation performs globally or in group as default.
* \param recv The array receives the outcome of allreduce
*/
-TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
+TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv);
/*!
* \brief Perform an allgather operation using the underlying communication library
* \param send The array send to perform allgather on
+ * \param in_group Whether the allgather operation performs globally or in group as default.
* \param recv The array receives the outcome of allgather
*/
-TVM_DLL void AllGather(NDArray send, NDArray recv);
+TVM_DLL void AllGather(NDArray send, bool in_group, NDArray recv);
/*!
* \brief Perform a broadcast operation from worker-0
* \param send The buffer to be broadcasted
+ * \param in_group Whether the broadcast operation performs globally or in group as default.
* \param recv The buffer receives the broadcasted array
*/
-TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv);
+TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv);
/*!
* \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts.
* \param send For worker-0, it must be provided, and otherwise, the buffer must be None.
* The buffer will be divided into equal parts and sent to each worker accordingly.
+ * \param in_group Whether the scatter operation performs globally or in group as default.
* \param recv The receiving buffer, which must not be None.
*/
-TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
+TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, bool in_group, NDArray recv);
/*!
* \brief Perform a gather operation to worker-0.
* \param send The sending buffer, which must not be None.
+ * \param in_group Whether the gather operation performs globally or in group as default.
* \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The
* receiving buffer will be divided into equal parts and receive from each worker accordingly.
*/
-TVM_DLL void GatherToWorker0(NDArray send, Optional<NDArray> recv);
+TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv);
/*!
* \brief Receive a buffer from worker-0. No-op if the current worker is worker-0.
* \param buffer The buffer to be received
diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h
index 14f8f23807..301b5b8d62 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -44,14 +44,16 @@ class DiscoWorker {
* \brief Construct a worker.
* \param worker_id The id of the worker.
* \param num_workers The number of the workers.
+ * \param num_groups The number of the worker groups.
* \param worker_zero_data The data shared between worker-0 and the controler. It's a nullptr if
* the worker is not worker-0.
* \param channel The communication channel between the worker and the controler.
*/
- explicit DiscoWorker(int worker_id, int num_workers, WorkerZeroData* worker_zero_data,
- DiscoChannel* channel)
+ explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
+ WorkerZeroData* worker_zero_data, DiscoChannel* channel)
: worker_id(worker_id),
num_workers(num_workers),
+ num_groups(num_groups),
default_device(Device{DLDeviceType::kDLCPU, 0}),
worker_zero_data(worker_zero_data),
channel(channel),
@@ -68,6 +70,8 @@ class DiscoWorker {
int worker_id;
/*! \brief Total number of workers */
int num_workers;
+ /*! \brief Total number of worker groups */
+ int num_groups;
/*! \brief The default device to allocate data if not specified */
Device default_device;
/*! \brief The name of the underlying collective communication library. */
diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h
index 71fcce75b2..97fa79096d 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -264,11 +264,13 @@ class Session : public ObjectRef {
/*!
* \brief Create a session backed by a thread pool of workers
* \param num_workers The number of workers.
+ * \param num_groups The number of worker groups.
*/
- TVM_DLL static Session ThreadedSession(int num_workers);
+ TVM_DLL static Session ThreadedSession(int num_workers, int num_groups);
/*!
* \brief Create a session backed by pipe-based multiprocessing
* \param num_workers The number of workers.
+ * \param num_groups The number of worker groups.
* \param process_pool_creator The name of a global function that takes `num_workers` as an input,
* and returns a PackedFunc, which takes an integer `worker_id` as the input and returns None.
* When `worker-id` is 0, it shuts down the process pool; Otherwise, it returns a tuple
@@ -277,8 +279,8 @@ class Session : public ObjectRef {
* \note Worker-0 is always co-located with the controler as a separate thread, and therefore
* worker-0 does not exist in the process pool.
*/
- TVM_DLL static Session ProcessSession(int num_workers, String process_pool_creator,
String entrypoint);
+ TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
+ String process_pool_creator, String entrypoint);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj);
};
diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
index 76ce0ff993..b1f1554b56 100644
--- a/python/tvm/exec/disco_worker.py
+++ b/python/tvm/exec/disco_worker.py
@@ -99,22 +99,23 @@ def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], NDArray]:
def main():
"""Main worker function"""
- if len(sys.argv) != 5:
- print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>")
+ if len(sys.argv) != 6:
+ print("Usage: <worker_id> <num_workers> <num_groups> <read_fd> <write_fd>")
return
worker_id = int(sys.argv[1])
num_workers = int(sys.argv[2])
+ num_groups = int(sys.argv[3])
if sys.platform == "win32":
import msvcrt # pylint: disable=import-outside-toplevel,import-error
- reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY)
- writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
+ reader = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
+ writer = msvcrt.open_osfhandle(int(sys.argv[5]), os.O_BINARY)
else:
- reader = int(sys.argv[3])
- writer = int(sys.argv[4])
+ reader = int(sys.argv[4])
+ writer = int(sys.argv[5])
worker_func = get_global_func("runtime.disco.WorkerProcess")
- worker_func(worker_id, num_workers, reader, writer)
+ worker_func(worker_id, num_workers, num_groups, reader, writer)
if __name__ == "__main__":
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index ec072f663c..e1ba4483c7 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1671,16 +1671,21 @@ def interpolate(
)
-def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"):
+def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, name="ccl_allreduce"):
"""CCL Allreduce operator
Parameters
----------
- x : Tensor
+ x : relax.Expr
The input tensor.
- op_type: str
+
+ op_type : str
The type of reduction operation to be applied to the input data.
Now "sum", "prod", "min", "max" and "avg" are supported.
+
+ in_group : bool
+ Whether the reduction operation performs globally or in group as default.
+
name : str
Name hint for this operation.
@@ -1689,7 +1694,7 @@ def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"):
result : Tensor
The result tensor of allreduce.
"""
- return wrap_nested(_op.ccl.allreduce(x._expr, op_type), name)
+ return wrap_nested(_op.ccl.allreduce(x._expr, op_type, in_group), name)
def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"):
diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 21c7946120..982c048021 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -15,25 +15,26 @@
# specific language governing permissions and limitations
# under the License.
"""Relax Collective Communications Library (CCL) operators"""
-from typing import Union
-from tvm.relax import PrimValue
from . import _ffi_api
from ...expr import Expr
-from ....ir import PrimExpr
-def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name
+def allreduce(x, op_type: str = "sum", in_group: bool = True): # pylint: disable=invalid-name
"""Allreduce operator
Parameters
----------
x : relax.Expr
The input tensor.
- op_type: str
+
+ op_type : str
The type of reduction operation to be applied to the input data.
Now "sum", "prod", "min", "max" and "avg" are supported.
+ in_group : bool
+ Whether the reduction operation performs globally or in group as default.
+
Returns
-------
result : relax.Expr
@@ -44,10 +45,10 @@ def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name
"Allreduce only supports limited reduction operations, "
f"including {supported_op_types}, but got {op_type}."
)
- return _ffi_api.allreduce(x, op_type) # type: ignore # pylint: disable=no-member
+ return _ffi_api.allreduce(x, op_type, in_group) # type: ignore # pylint: disable=no-member
-def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]): # pylint: disable=invalid-name
+def allgather(x, num_workers: int, in_group: bool = True): # pylint: disable=invalid-name
"""AllGather operator
Parameters
@@ -55,17 +56,18 @@ def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]): # pylint: disab
x : relax.Expr
The input tensor.
- num_worker : Union[int, PrimExpr, PrimValue]
+ num_worker : int
The number of workers to gather data from.
+ in_group : bool
+ Whether the gather operation performs globally or in group as default.
+
Returns
-------
result : relax.Expr
The result of allgather.
"""
- if not isinstance(num_workers, PrimValue):
- num_workers = PrimValue(num_workers)
- return _ffi_api.allgather(x, num_workers) # type: ignore # pylint: disable=no-member
+ return _ffi_api.allgather(x, num_workers, in_group) # type: ignore # pylint: disable=no-member
def broadcast_from_worker0(x: Expr) -> Expr:
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py
index ae0be3c228..364dee750e 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -41,7 +41,7 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
)
return call_dps_packed(
"runtime.disco.allreduce",
- [call.args[0], ShapeExpr([op_type_map[op_type_str]])],
+ [call.args[0], ShapeExpr([op_type_map[op_type_str]]), call.attrs.in_group],
out_sinfo=call.args[0].struct_info,
)
@@ -57,12 +57,12 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr:
arg_shape = arg_sinfo.shape.struct_info
for i, shape_value in enumerate(arg_shape.values):
if i == 0:
- output_shape.append(shape_value * call.args[1].value)
+ output_shape.append(shape_value * call.attrs.num_workers)
else:
output_shape.append(shape_value)
return call_dps_packed(
"runtime.disco.allgather",
- call.args[0],
+ [call.args[0], call.attrs.in_group],
out_sinfo=TensorStructInfo(
shape=output_shape,
dtype=arg_sinfo.dtype,
@@ -75,7 +75,7 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr:
def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
return call_dps_packed(
"runtime.disco.broadcast_from_worker0",
- call.args[0],
+ [call.args[0], False],
out_sinfo=call.args[0].struct_info,
)
@@ -116,7 +116,7 @@ def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
output_shape = output_shape[1:]
return call_dps_packed(
"runtime.disco.scatter_from_worker0",
- transpose_var,
+ [transpose_var, False],
out_sinfo=TensorStructInfo(
shape=output_shape,
dtype=call.args[0].struct_info.dtype,
diff --git a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py
index 1ad8659d60..95969e038e 100644
--- a/python/tvm/runtime/disco/process_pool.py
+++ b/python/tvm/runtime/disco/process_pool.py
@@ -38,6 +38,9 @@ class DiscoPopenWorker:
num_workers : int
The total number of workers.
+ num_groups : int
+ The total number of worker groups.
+
stdout: Union[None, int, IO[Any]]
The standard output streams handler specified for the popen process.
@@ -49,12 +52,14 @@ class DiscoPopenWorker:
self,
worker_id: int,
num_workers: int,
+ num_groups: int,
entrypoint: str = "tvm.exec.disco_worker",
stdout=None,
stderr=None,
):
self.worker_id = worker_id
self.num_workers = num_workers
+ self.num_groups = num_groups
self.entrypoint = entrypoint
self._proc = None
self._stdout = stdout
@@ -118,6 +123,7 @@ class DiscoPopenWorker:
self.entrypoint,
str(self.worker_id),
str(self.num_workers),
+ str(self.num_groups),
]
if sys.platform == "win32":
import msvcrt # pylint: disable=import-error,import-outside-toplevel
@@ -172,9 +178,9 @@ def _kill_child_processes(pid):
@register_func("runtime.disco.create_process_pool")
-def _create_process_pool(num_workers: int, entrypoint: str):
+def _create_process_pool(num_workers: int, num_groups: int, entrypoint: str):
"""Create a process pool where the workers' are [1, num_workers)."""
- pool = [DiscoPopenWorker(i, num_workers, entrypoint) for i in range(1, num_workers)]
+ pool = [DiscoPopenWorker(i, num_workers, num_groups, entrypoint) for i in range(1, num_workers)]
def result_func(worker_id: int):
nonlocal pool
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index ddde1bc1f3..38c4f2a235 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -66,6 +66,7 @@ class DRef(Object):
----------
worker_id : int
The id of the worker to be copied to.
+
value : Union[numpy.ndarray, NDArray]
The value to be copied.
"""
@@ -121,6 +122,7 @@ class Session(Object):
dtype: str,
device: Optional[Device] = None,
worker0_only: bool = False,
+ in_group: bool = True,
) -> DRef:
"""Create an empty NDArray on all workers and attach them to a DRef.
@@ -139,6 +141,11 @@ class Session(Object):
If False (default), allocate an array on each worker. If
True, only allocate an array on worker0.
+ in_group: bool
+ Take effective when `worker0_only` is True. If True (default),
+ allocate an array on each first worker in each group. If
+ False, only allocate an array on worker0 globally.
+
Returns
-------
array : DRef
@@ -148,7 +155,7 @@ class Session(Object):
if device is None:
device = Device(device_type=0, device_id=0)
func = self._get_cached_method("runtime.disco.empty")
- return func(ShapeTuple(shape), dtype, device, worker0_only)
+ return func(ShapeTuple(shape), dtype, device, worker0_only, in_group)
def shutdown(self):
"""Shut down the Disco session"""
@@ -244,6 +251,7 @@ class Session(Object):
----------
host_array : numpy.ndarray
The array to be copied to worker-0.
+
remote_array : NDArray
The NDArray on worker-0.
"""
@@ -255,11 +263,9 @@ class Session(Object):
Parameters
----------
host_array : NDArray
-
The array to be copied to worker-0.
remote_array : Optiona[DRef]
-
The destination NDArray on worker-0.
Returns
@@ -289,6 +295,7 @@ class Session(Object):
----------
path : str
The path to the VM module file.
+
device : Optional[Device] = None
The device to load the VM module to. Default to the default device
of each worker.
@@ -312,6 +319,7 @@ class Session(Object):
- nccl
- rccl
- mpi
+
*device_ids : int
The device IDs to be used by the underlying communication library.
"""
@@ -319,20 +327,23 @@ class Session(Object):
_ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member
self._clear_ipc_memory_pool()
- def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
+ def broadcast(
+ self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True
+ ) -> DRef:
"""Broadcast an array to all workers
Parameters
----------
src: Union[np.ndarray, NDArray]
-
The array to be broadcasted.
dst: Optional[DRef]
-
The output array. If None, an array matching the shape
and dtype of `src` will be allocated on each worker.
+ in_group: bool
+ Whether the broadcast operation performs globally or in group as default.
+
Returns
-------
output_array: DRef
@@ -349,38 +360,48 @@ class Session(Object):
dst = self.empty(src.shape, src.dtype)
src_dref = self.copy_to_worker_0(src)
- self.broadcast_from_worker0(src_dref, dst)
+ self.broadcast_from_worker0(src_dref, dst, in_group)
return dst
- def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
+ def broadcast_from_worker0(self, src: DRef, dst: DRef, in_group: bool = True) -> DRef:
"""Broadcast an array from worker-0 to all other workers.
Parameters
----------
- array : DRef
- The array to be broadcasted in-place
+ src: Union[np.ndarray, NDArray]
+ The array to be broadcasted.
+
+ dst: Optional[DRef]
+ The output array. If None, an array matching the shape
+ and dtype of `src` will be allocated on each worker.
+
+ in_group: bool
+ Whether the broadcast operation performs globally or in group as default.
"""
func = self._get_cached_method("runtime.disco.broadcast_from_worker0")
- func(src, dst)
+ func(src, in_group, dst)
- def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
+ def scatter(
+ self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True
+ ) -> DRef:
"""Scatter an array across all workers
Parameters
----------
src: Union[np.ndarray, NDArray]
-
The array to be scattered. The first dimension of this
array, `src.shape[0]`, must be equal to the number of
workers.
dst: Optional[DRef]
-
The output array. If None, an array with compatible shape
and the same dtype as `src` will be allocated on each
worker.
+ in_group: bool
+ Whether the scatter operation performs globally or in group as default.
+
Returns
-------
output_array: DRef
@@ -399,41 +420,54 @@ class Session(Object):
dst = self.empty(src.shape[1:], src.dtype)
src_dref = self.copy_to_worker_0(src)
- self.scatter_from_worker0(src_dref, dst)
+ self.scatter_from_worker0(src_dref, dst, in_group)
return dst
- def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None:
+ def scatter_from_worker0(self, from_array: DRef, to_array: DRef, in_group: bool = True) -> None:
"""Scatter an array from worker-0 to all other workers.
Parameters
----------
- from_array : DRef
- The array to be scattered from.
- to_array : DRef
- The array to be scattered to.
+ src: Union[np.ndarray, NDArray]
+ The array to be scattered. The first dimension of this
+ array, `src.shape[0]`, must be equal to the number of
+ workers.
+
+ dst: Optional[DRef]
+ The output array. If None, an array with compatible shape
+ and the same dtype as `src` will be allocated on each
+ worker.
+
+ in_group: bool
+ Whether the scatter operation performs globally or in group as default.
"""
func = self._get_cached_method("runtime.disco.scatter_from_worker0")
- func(from_array, to_array)
+ func(from_array, in_group, to_array)
- def gather_to_worker0(self, from_array: DRef, to_array: DRef) -> None:
+ def gather_to_worker0(self, from_array: DRef, to_array: DRef, in_group: bool = True) -> None:
"""Gather an array from all other workers to worker-0.
Parameters
----------
from_array : DRef
The array to be gathered from.
+
to_array : DRef
The array to be gathered to.
+
+ in_group: bool
+ Whether the gather operation performs globally or in group as default.
"""
func = self._get_cached_method("runtime.disco.gather_to_worker0")
- func(from_array, to_array)
+ func(from_array, in_group, to_array)
def allreduce(
self,
src: DRef,
dst: DRef,
op: str = "sum", # pylint: disable=invalid-name
+ in_group: bool = True,
) -> DRef:
"""Perform an allreduce operation on an array.
@@ -441,6 +475,7 @@ class Session(Object):
----------
array : DRef
The array to be reduced.
+
op : str = "sum"
The reduce operation to be performed. Available options are:
- "sum"
@@ -448,17 +483,21 @@ class Session(Object):
- "min"
- "max"
- "avg"
+
+ in_group : bool
+ Whether the reduce operation performs globally or in group as default.
"""
if op not in REDUCE_OPS:
raise ValueError(f"Unsupported reduce op: {op}. Available ops are: {REDUCE_OPS.keys()}")
op = ShapeTuple([REDUCE_OPS[op]])
func = self._get_cached_method("runtime.disco.allreduce")
- func(src, op, dst)
+ func(src, op, in_group, dst)
def allgather(
self,
src: DRef,
dst: DRef,
+ in_group: bool = True,
) -> DRef:
"""Perform an allgather operation on an array.
@@ -466,11 +505,15 @@ class Session(Object):
----------
src : DRef
The array to be gathered from.
+
dst : DRef
The array to be gathered to.
+
+ in_group : bool
+ Whether the reduce operation performs globally or in group as default.
"""
func = self._get_cached_method("runtime.disco.allgather")
- func(src, dst)
+ func(src, in_group, dst)
def _clear_ipc_memory_pool(self):
# Clear the IPC memory allocator when the allocator exists.
@@ -483,11 +526,12 @@ class Session(Object):
class ThreadedSession(Session):
"""A Disco session backed by multi-threading."""
- def __init__(self, num_workers: int) -> None:
+ def __init__(self, num_workers: int, num_groups: int = 1) -> None:
"""Create a disco session backed by multiple threads in the same process."""
self.__init_handle_by_constructor__(
_ffi_api.SessionThreaded, # type: ignore # pylint: disable=no-member
num_workers,
+ num_groups,
)
@@ -495,10 +539,13 @@ class ThreadedSession(Session):
class ProcessSession(Session):
"""A Disco session backed by pipe-based multi-processing."""
- def __init__(self, num_workers: int, entrypoint: str = "tvm.exec.disco_worker") -> None:
+ def __init__(
+ self, num_workers: int, num_groups: int = 1, entrypoint: str = "tvm.exec.disco_worker"
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member
num_workers,
+ num_groups,
"runtime.disco.create_process_pool",
entrypoint,
)
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index c0fe6f4d88..092727cb51 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -27,9 +27,10 @@ namespace relax {
/* relax.ccl.allreduce */
TVM_REGISTER_NODE_TYPE(AllReduceAttrs);
-Expr allreduce(Expr x, String op_type) {
+Expr allreduce(Expr x, String op_type, bool in_group) {
ObjectPtr<AllReduceAttrs> attrs = make_object<AllReduceAttrs>();
attrs->op_type = std::move(op_type);
+ attrs->in_group = std::move(in_group);
static const Op& op = Op::Get("relax.ccl.allreduce");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
@@ -51,19 +52,24 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
.set_attr<Bool>("FPurity", Bool(true));
/* relax.ccl.allgather */
-Expr allgather(Expr x, Expr num_workers) {
+TVM_REGISTER_NODE_TYPE(AllGatherAttrs);
+
+Expr allgather(Expr x, int num_workers, bool in_group) {
+ ObjectPtr<AllGatherAttrs> attrs = make_object<AllGatherAttrs>();
+ attrs->num_workers = std::move(num_workers);
+ attrs->in_group = std::move(in_group);
+
static const Op& op = Op::Get("relax.ccl.allgather");
- return Call(op, {std::move(x), std::move(num_workers)});
+ return Call(op, {std::move(x)}, Attrs{attrs}, {});
}
TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather);
StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) {
- CHECK_EQ(call->args.size(), 2);
- auto input_sinfo = Downcast<TensorStructInfo>(call->args[0]->struct_info_);
- auto num_workers_sinfo = Downcast<PrimStructInfo>(call->args[1]->struct_info_);
+ TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
- auto num_workers = num_workers_sinfo->value;
+ const auto* attrs = call->attrs.as<AllGatherAttrs>();
+ int num_workers = attrs->num_workers;
DataType output_dtype = input_sinfo->dtype;
auto input_shape = input_sinfo->GetShape();
@@ -71,7 +77,7 @@ StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) {
return input_sinfo;
}
Array<PrimExpr> output_shape = input_shape.value();
- output_shape.Set(0, floor(output_shape[0] * num_workers.value()));
+ output_shape.Set(0, floor(output_shape[0] * num_workers));
return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice);
}
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 3e7f0220c9..82ea393567 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -33,10 +33,10 @@ namespace tvm {
namespace relax {
/*! \brief AllReduce. */
-Expr allreduce(Expr data, String op_type);
+Expr allreduce(Expr data, String op_type, bool in_group);
/*! \brief AllGather. */
-Expr allgather(Expr data, Expr num_workers);
+Expr allgather(Expr data, int num_workers, bool in_group);
/*! \brief Broadcast data from worker-0 to all other workers. */
Expr broadcast_from_worker0(Expr data);
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 26d1c22ee9..0cb2ee6f5d 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -79,22 +79,24 @@ const PackedFunc& GetCCLFunc(const char* name) {
return *pf;
}
-void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
- GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), recv);
+void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) {
+ GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), in_group, recv);
}
-void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); }
+void AllGather(NDArray send, bool in_group, NDArray recv) {
+ GetCCLFunc("allgather")(send, in_group, recv);
+}
-TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv) {
- GetCCLFunc("broadcast_from_worker0")(send, recv);
+TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv) {
+ GetCCLFunc("broadcast_from_worker0")(send, in_group, recv);
}
-TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
- GetCCLFunc("scatter_from_worker0")(send, recv);
+TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, bool in_group, NDArray recv) {
+ GetCCLFunc("scatter_from_worker0")(send, in_group, recv);
}
-void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
- GetCCLFunc("gather_to_worker0")(send, recv);
+void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv) {
+ GetCCLFunc("gather_to_worker0")(send, in_group, recv);
}
void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); }
@@ -110,9 +112,13 @@ void SyncWorker() {
TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
TVM_REGISTER_GLOBAL("runtime.disco.empty")
- .set_body_typed([](ShapeTuple shape, DataType dtype, Device device,
- bool worker0_only) -> Optional<NDArray> {
- if (worker0_only && WorkerId()) {
+ .set_body_typed([](ShapeTuple shape, DataType dtype, Device device, bool worker0_only,
+ bool in_group) -> Optional<NDArray> {
+ int worker_id = WorkerId();
+ int group_size =
+ DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups;
+ bool is_worker0 = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0);
+ if (worker0_only && !is_worker0) {
return NullOpt;
} else {
return DiscoEmptyNDArray(shape, dtype, device);
@@ -120,10 +126,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.empty")
});
TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
- .set_body_typed([](NDArray send, ShapeTuple reduce_kind, NDArray recv) {
+ .set_body_typed([](NDArray send, ShapeTuple reduce_kind, bool in_group, NDArray recv) {
int kind = IntegerFromShapeTuple(reduce_kind);
CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
- AllReduce(send, static_cast<ReduceKind>(kind), recv);
+ AllReduce(send, static_cast<ReduceKind>(kind), in_group, recv);
});
TVM_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather);
TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0);
diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
index fec5abec86..490217d62c 100644
--- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
+++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
@@ -47,8 +47,8 @@ std::vector<cudaIpcMemHandle_t> AllGatherIPCHandles(nccl::CCLThreadLocalContext*
CUDA_CALL(cudaMalloc(&d_src, CUDA_IPC_HANDLE_SIZE));
CUDA_CALL(cudaMalloc(&d_dst, CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers));
CUDA_CALL(cudaMemcpy(d_src, &local_handle, CUDA_IPC_HANDLE_SIZE, cudaMemcpyHostToDevice));
- NCCL_CALL(
- ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->comm, /*stream=*/nullptr));
+ NCCL_CALL(ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->global_comm,
+ /*stream=*/nullptr));
std::vector<char> serial_handles(CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, 0);
CUDA_CALL(cudaMemcpy(serial_handles.data(), d_dst,
CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, cudaMemcpyDefault));
diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/disco/cuda_ipc/custom_allreduce.cc
index 98fd777b83..d969005f94 100644
--- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc
+++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc
@@ -65,6 +65,8 @@ inline bool CanApplyTwoShotAllReduce(int64_t num_elements, DLDataType dtype, int
void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) {
int64_t num_elements = TensorSize(send);
nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get();
+ CHECK_EQ(ctx->worker->num_groups, 1)
+ << "Custom AllReduce for multiple group is not yet implemented.";
tensorrt_llm::AllReduceStrategyType strategy_ =
static_cast<tensorrt_llm::AllReduceStrategyType>(strategy);
@@ -79,7 +81,7 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) {
deviceStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclAllReduce(send->data, recv->data, num_elements,
/*datatype=*/nccl::AsNCCLDataType(DataType(send->dtype)),
- /*op=*/ncclSum, ctx->comm, stream));
+ /*op=*/ncclSum, ctx->global_comm, stream));
return;
}
diff --git a/src/runtime/disco/disco_worker_thread.h b/src/runtime/disco/disco_worker_thread.h
index 67742cdd04..8d6b44396f 100644
--- a/src/runtime/disco/disco_worker_thread.h
+++ b/src/runtime/disco/disco_worker_thread.h
@@ -47,12 +47,14 @@ class DiscoWorkerThread {
* \brief Construct a worker thread.
* \param worker_id The id of the worker.
* \param num_workers The total number of workers.
+ * \param num_groups The total number of worker groups.
* \param worker_zero_data_ The data shared between worker-0 and the controler. It's a nullptr if
* the worker is not worker-0.
* \note This method is implemented in threaded worker, because it depends on creation of a
* sub-class of DiscoChannel, DiscoThreadChannel, which is hidden from the public interface.
*/
- explicit DiscoWorkerThread(int worker_id, int num_workers, WorkerZeroData* worker_zero_data_);
+ explicit DiscoWorkerThread(int worker_id, int num_workers, int num_groups,
+ WorkerZeroData* worker_zero_data_);
/*! \brief Move constructor. */
explicit DiscoWorkerThread(DiscoWorkerThread&& other)
diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index 7a5d978946..efe42539cb 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -326,19 +326,19 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) {
w = this->ApplyShardFunc(shard_func, w);
}
- ScatterFromWorker0(w, recv);
+ ScatterFromWorker0(w, /*in_group=*/false, recv);
} else {
- ScatterFromWorker0(NullOpt, recv);
+ ScatterFromWorker0(NullOpt, /*in_group=*/false, recv);
}
return recv;
} else {
if (worker_id == 0) {
NDArray w = LoadDirect(weight_index);
- BroadcastFromWorker0(w, w);
+ BroadcastFromWorker0(w, /*in_group=*/false, w);
return w;
} else {
NDArray w = NDArray::Empty(param->shape, param->dtype, device);
- BroadcastFromWorker0(w, w);
+ BroadcastFromWorker0(w, /*in_group=*/false, w);
return w;
}
}
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index bba42ed3bd..2d2c528b52 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -72,9 +72,12 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) {
<< "ValueError: The length of unique_id must be " << NCCL_UNIQUE_ID_BYTES << ", but got "
<< unique_id_bytes.size() << ".";
- CHECK(!ctx->comm) << "Cannot initialize CCL, "
- << "the previous thread-global comm still exists, "
- << "and has not been destructed";
+ CHECK(!ctx->global_comm) << "Cannot initialize CCL, "
+ << "the previous thread-global comm still exists, "
+ << "and has not been destructed";
+ CHECK(!ctx->group_comm) << "Cannot initialize CCL, "
+ << "the previous thread-group comm still exists, "
+ << "and has not been destructed";
CHECK(!ctx->default_stream) << "Cannot initialize CCL, "
<< "the previous thread-global stream still exists, "
<< "and has not been destructed";
@@ -96,34 +99,41 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) {
// Initialize the communicator
ncclUniqueId id;
std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES);
- NCCL_CALL(ncclCommInitRank(&ctx->comm, worker->num_workers, id, worker->worker_id));
+ int group_size = worker->num_workers / worker->num_groups;
+ NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, worker->worker_id));
+ NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size,
+ worker->worker_id % group_size, &ctx->group_comm, NULL));
}
-void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
+void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
deviceStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
- /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
+ /*op=*/AsNCCLRedOp(reduce_kind),
+ in_group ? ctx->group_comm : ctx->global_comm, stream));
}
-void AllGather(NDArray send, NDArray recv) {
+void AllGather(NDArray send, bool in_group, NDArray recv) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
deviceStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclAllGather(send->data, recv->data, numel,
- /*datatype=*/AsNCCLDataType(DataType(send->dtype)), ctx->comm, stream));
+ /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
+ in_group ? ctx->group_comm : ctx->global_comm, stream));
}
-void BroadcastFromWorker0(Optional<NDArray> send, NDArray recv) {
+void BroadcastFromWorker0(Optional<NDArray> send, bool in_group, NDArray recv) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ int worker_id = ctx->worker->worker_id;
+ int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+ bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0);
const void* send_data = [&]() -> const void* {
- int worker_id = ctx->worker->worker_id;
- if (worker_id == 0) {
+ if (is_sender) {
CHECK(send.defined());
CHECK(send.value().Shape()->Product() == recv.Shape()->Product());
return send.value()->data;
@@ -136,25 +146,28 @@ void BroadcastFromWorker0(Optional<NDArray> send, NDArray recv) {
deviceStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclBroadcast(send_data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(recv->dtype)),
- /*root=*/0, ctx->comm, stream));
+ /*root=*/0, in_group ? ctx->group_comm : ctx->global_comm, stream));
}
-void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
+void ScatterFromWorker0(Optional<NDArray> send, bool in_group, NDArray recv) {
CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None";
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
int worker_id = ctx->worker->worker_id;
int num_workers = ctx->worker->num_workers;
+ int group_size = num_workers / ctx->worker->num_groups;
+ bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0);
+ int num_receiver = in_group ? group_size : num_workers;
deviceStream_t stream = ctx->GetDefaultStream();
- if (worker_id == 0) {
+ if (is_sender) {
CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0.";
NDArray buffer = send.value();
int64_t numel = buffer.Shape()->Product();
- CHECK_EQ(numel % num_workers, 0) << "ValueError: Scattering evenly requires that the number "
- "of elements in the buffer to be "
- "divisible by the number of workers, but got numel = "
- << numel << " and " << num_workers << " workers.";
+ CHECK_EQ(numel % num_receiver, 0) << "ValueError: Scattering evenly requires that the number "
+ "of elements in the buffer to be "
+ "divisible by the number of workers, but got numel = "
+ << numel << " and " << num_receiver << " workers.";
DataType dtype(buffer->dtype);
- int64_t numel_per_shard = numel / num_workers;
+ int64_t numel_per_shard = numel / num_receiver;
int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
CHECK_EQ(numel_per_shard, recv.Shape()->Product())
<< "ValueError: The number of elements in buffer `recv` must be the same as each shard "
@@ -163,40 +176,45 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
<< numel << ", but `recv.size` is " << recv.Shape()->Product() << ".";
NCCL_CALL(ncclGroupStart());
uint8_t* data = static_cast<uint8_t*>(buffer->data);
- for (int i = 0; i < num_workers; ++i) {
- NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, stream));
+ for (int i = 0; i < num_receiver; ++i) {
+ NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i,
+ in_group ? ctx->group_comm : ctx->global_comm, stream));
data += bytes_per_shard;
}
} else {
if (send.defined()) {
- LOG(WARNING) << "Buffer `send` must be None when worker_id != 0, but got "
- "send = "
+ LOG(WARNING) << "ValueError: buffer `send` must be None when (worker_id != 0 && !in_group) "
+ "or (worker_id % group_size != 0 && in_group). However, got send = "
<< send.get() << ". This will be ignored.";
}
NCCL_CALL(ncclGroupStart());
}
int64_t numel = recv.Shape()->Product();
DataType dtype(recv->dtype);
- NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, stream));
+ NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0,
+ in_group ? ctx->group_comm : ctx->global_comm, stream));
NCCL_CALL(ncclGroupEnd());
}
-void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
+void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv) {
CHECK(send.defined()) << "ValueError: buffer `send` must not be None";
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
int worker_id = ctx->worker->worker_id;
int num_workers = ctx->worker->num_workers;
+ int group_size = num_workers / ctx->worker->num_groups;
+ bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0);
+ int num_receiver = in_group ? group_size : num_workers;
deviceStream_t stream = ctx->GetDefaultStream();
- if (worker_id == 0) {
+ if (is_sender) {
CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
NDArray buffer = recv.value();
int64_t numel = buffer.Shape()->Product();
- CHECK_EQ(numel % num_workers, 0) << "ValueError: Gathering evenly requires that the number "
- "of elements in the buffer to be "
- "divisible by the number of workers, but got numel = "
- << numel << " and " << num_workers << " workers.";
+ CHECK_EQ(numel % num_receiver, 0) << "ValueError: Gathering evenly requires that the number "
+ "of elements in the buffer to be "
+ "divisible by the number of workers, but got numel = "
+ << numel << " and " << num_receiver << " workers.";
DataType dtype(buffer->dtype);
- int64_t numel_per_shard = numel / num_workers;
+ int64_t numel_per_shard = numel / num_receiver;
int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
CHECK_EQ(numel_per_shard, send.Shape()->Product())
<< "ValueError: The number of elements in buffer `send` must be the same as each shard "
@@ -205,21 +223,23 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
<< numel << ", but `send.size` is " << send.Shape()->Product() << ".";
NCCL_CALL(ncclGroupStart());
uint8_t* data = static_cast<uint8_t*>(buffer->data);
- for (int i = 0; i < num_workers; ++i) {
- NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, stream));
+ for (int i = 0; i < num_receiver; ++i) {
+ NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i,
+ in_group ? ctx->group_comm : ctx->global_comm, stream));
data += bytes_per_shard;
}
} else {
if (recv.defined()) {
- LOG(WARNING) << "ValueError: buffer `recv` must be None when worker_id != 0. However, got "
- "recv = "
+ LOG(WARNING) << "ValueError: buffer `recv` must be None when (worker_id != 0 && !in_group) "
+ "or (worker_id % group_size != 0 && in_group). However, got recv = "
<< recv.get() << ". This will be ignored.";
}
NCCL_CALL(ncclGroupStart());
}
int64_t numel = send.Shape()->Product();
DataType dtype(send->dtype);
- NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, stream));
+ NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0,
+ in_group ? ctx->group_comm : ctx->global_comm, stream));
NCCL_CALL(ncclGroupEnd());
}
@@ -230,7 +250,7 @@ void RecvFromWorker0(NDArray buffer) {
<< "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
NCCL_CALL(ncclGroupStart());
NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), 0,
- ctx->comm, stream));
+ ctx->global_comm, stream));
NCCL_CALL(ncclGroupEnd());
}
@@ -248,12 +268,14 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_ty
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker")
.set_body_typed(InitCCLPerWorker);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce")
- .set_body_typed([](NDArray send, int kind, NDArray recv) {
+ .set_body_typed([](NDArray send, int kind, bool in_group, NDArray recv) {
CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
- nccl::AllReduce(send, static_cast<ReduceKind>(kind), recv);
+ nccl::AllReduce(send, static_cast<ReduceKind>(kind), in_group, recv);
});
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather")
- .set_body_typed([](NDArray send, NDArray recv) { nccl::AllGather(send, recv); });
+ .set_body_typed([](NDArray send, bool in_group, NDArray recv) {
+ nccl::AllGather(send, in_group, recv);
+ });
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0")
.set_body_typed(BroadcastFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0")
diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h
index 3fb281f2cb..730479b61a 100644
--- a/src/runtime/disco/nccl/nccl_context.h
+++ b/src/runtime/disco/nccl/nccl_context.h
@@ -121,14 +121,19 @@ struct CCLThreadLocalContext {
DiscoWorker* worker = nullptr;
int device_id;
deviceStream_t default_stream = nullptr;
- ncclComm_t comm = nullptr;
+ ncclComm_t global_comm = nullptr;
+ ncclComm_t group_comm = nullptr;
~CCLThreadLocalContext() { Clear(); }
void Clear() {
- if (comm) {
- NCCL_CALL(ncclCommDestroy(comm));
- comm = nullptr;
+ if (group_comm) {
+ NCCL_CALL(ncclCommDestroy(group_comm));
+ group_comm = nullptr;
+ }
+ if (global_comm) {
+ NCCL_CALL(ncclCommDestroy(global_comm));
+ global_comm = nullptr;
}
if (default_stream) {
StreamDestroy(default_stream);
diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc
index 179010db8a..7c8d0796dd 100644
--- a/src/runtime/disco/process_session.cc
+++ b/src/runtime/disco/process_session.cc
@@ -154,9 +154,10 @@ class DiscoProcessChannel final : public DiscoChannel {
class ProcessSessionObj final : public BcastSessionObj {
public:
- explicit ProcessSessionObj(int num_workers, PackedFunc process_pool)
+ explicit ProcessSessionObj(int num_workers, int num_groups, PackedFunc process_pool)
: process_pool_(process_pool),
- worker_0_(std::make_unique<DiscoWorkerThread>(0, num_workers, &worker_zero_data_)) {
+ worker_0_(
+ std::make_unique<DiscoWorkerThread>(0, num_workers, num_groups, &worker_zero_data_)) {
std::vector<int64_t> read_fds;
std::vector<int64_t> write_fds;
read_fds.reserve(num_workers - 1);
@@ -258,18 +259,24 @@ class ProcessSessionObj final : public BcastSessionObj {
TVM_REGISTER_OBJECT_TYPE(DiscoDebugObject);
TVM_REGISTER_OBJECT_TYPE(ProcessSessionObj);
-Session Session::ProcessSession(int num_workers, String process_pool_creator, String entrypoint) {
+Session Session::ProcessSession(int num_workers, int num_group, String process_pool_creator,
+ String entrypoint) {
+ CHECK_EQ(num_workers % num_group, 0)
+ << "The number of workers should be divisible by the number of worker group.";
const PackedFunc* pf = Registry::Get(process_pool_creator);
CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator
<< " in the registry. Please check if it is registered.";
- PackedFunc process_pool = (*pf)(num_workers, entrypoint);
- auto n = make_object<ProcessSessionObj>(num_workers, process_pool);
+ PackedFunc process_pool = (*pf)(num_workers, num_group, entrypoint);
+ auto n = make_object<ProcessSessionObj>(num_workers, num_group, process_pool);
return Session(n);
}
-void WorkerProcess(int worker_id, int num_workers, int64_t read_fd, int64_t write_fd) {
+void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_fd,
+ int64_t write_fd) {
+ CHECK_EQ(num_workers % num_group, 0)
+ << "The number of workers should be divisible by the number of worker group.";
DiscoProcessChannel channel(read_fd, write_fd);
- DiscoWorker worker(worker_id, num_workers, nullptr, &channel);
+ DiscoWorker worker(worker_id, num_workers, num_group, nullptr, &channel);
worker.MainLoop();
}
diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc
index 22f906b809..cc9a311a6b 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -133,20 +133,20 @@ class DiscoThreadChannel final : public DiscoChannel {
DiscoThreadedMessageQueue worker_to_controler_;
};
-DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers,
+DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers, int num_groups,
WorkerZeroData* worker_zero_data_)
: channel(std::make_unique<DiscoThreadChannel>()),
- worker(
- std::make_unique<DiscoWorker>(worker_id, num_workers, worker_zero_data_, channel.get())),
+ worker(std::make_unique<DiscoWorker>(worker_id, num_workers, num_groups, worker_zero_data_,
+ channel.get())),
thread(std::make_unique<std::thread>([worker = this->worker.get()] { worker->MainLoop(); })) {
}
class ThreadedSessionObj final : public BcastSessionObj {
public:
- explicit ThreadedSessionObj(int num_workers) {
+ explicit ThreadedSessionObj(int num_workers, int num_groups) {
for (int i = 0; i < num_workers; ++i) {
WorkerZeroData* data = (i == 0) ? &worker_zero_data_ : nullptr;
- workers_.emplace_back(i, num_workers, data);
+ workers_.emplace_back(i, num_workers, num_groups, data);
}
}
@@ -185,8 +185,10 @@ class ThreadedSessionObj final : public BcastSessionObj {
TVM_REGISTER_OBJECT_TYPE(ThreadedSessionObj);
-Session Session::ThreadedSession(int num_workers) {
- ObjectPtr<ThreadedSessionObj> n = make_object<ThreadedSessionObj>(num_workers);
+Session Session::ThreadedSession(int num_workers, int num_group) {
+ CHECK_EQ(num_workers % num_group, 0)
+ << "The number of workers should be divisible by the number of worker group.";
+ ObjectPtr<ThreadedSessionObj> n = make_object<ThreadedSessionObj>(num_workers, num_group);
return Session(std::move(n));
}
diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py
index 6e2dc9b747..3f8d5e9e52 100644
--- a/tests/python/disco/test_callback.py
+++ b/tests/python/disco/test_callback.py
@@ -30,16 +30,17 @@ from tvm.script import relax as R, tir as T
@tvm.testing.requires_nccl
def test_callback():
+ """Simulate lazy loading of parameters in a callback
+
+ The output of a lazy parameter loading, which would accept a
+ callback to load the parameters.
+ """
+
@R.function
def transform_params(
rank_arg: R.Prim(value="rank"),
fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object),
):
- """Simulate lazy loading of parameters in a callback
-
- The output of a lazy parameter loading, which would accept a
- callback to load the parameters.
- """
rank = T.int64()
A = fget_item(R.str("A"), R.prim_value(0))
diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py
index 5831f245df..6c63f64554 100644
--- a/tests/python/disco/test_ccl.py
+++ b/tests/python/disco/test_ccl.py
@@ -78,6 +78,42 @@ def test_allreduce(session_kind, ccl):
np.testing.assert_equal(result, expected)
+@pytest.mark.parametrize("session_kind", _all_session_kinds)
+@pytest.mark.parametrize("ccl", _ccl)
+def test_group_allreduce(session_kind, ccl):
+ devices = [0, 1, 2, 3]
+ sess = session_kind(num_workers=len(devices), num_groups=2)
+ sess.init_ccl(ccl, *devices)
+
+ array_1 = np.arange(12, dtype="float32").reshape(3, 4)
+ array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
+ array_3 = np.arange(30, dtype="float32").reshape(5, 6)
+ array_4 = np.arange(start=1, stop=-29, step=-1, dtype="float32").reshape(5, 6)
+ d_array_1 = sess.empty((3, 4), "float32")
+ d_array_2 = sess.empty((5, 6), "float32")
+ d_array_1.debug_copy_from(0, array_1)
+ d_array_1.debug_copy_from(1, array_2)
+ d_array_2.debug_copy_from(2, array_3)
+ d_array_2.debug_copy_from(3, array_4)
+ for op, np_op in [ # pylint: disable=invalid-name
+ ("sum", np.add),
+ ("prod", np.multiply),
+ ("min", np.minimum),
+ ("max", np.maximum),
+ ("avg", lambda a, b: (a + b) * 0.5),
+ ]:
+ dst_array_1 = sess.empty((3, 4), "float32")
+ dst_array_2 = sess.empty((5, 6), "float32")
+ sess.allreduce(d_array_1, dst_array_1, op=op, in_group=True)
+ sess.allreduce(d_array_2, dst_array_2, op=op, in_group=True)
+ result_1 = dst_array_1.debug_get_from_remote(0).numpy()
+ result_2 = dst_array_2.debug_get_from_remote(2).numpy()
+ expected_1 = np_op(array_1, array_2)
+ expected_2 = np_op(array_3, array_4)
+ np.testing.assert_equal(result_1, expected_1)
+ np.testing.assert_equal(result_2, expected_2)
+
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_allgather(session_kind, ccl):
@@ -101,10 +137,47 @@ def test_allgather(session_kind, ccl):
)
+@pytest.mark.parametrize("session_kind", _all_session_kinds)
+@pytest.mark.parametrize("ccl", _ccl)
+def test_group_allgather(session_kind, ccl):
+ devices = [0, 1, 2, 3]
+ sess = session_kind(num_workers=len(devices), num_groups=2)
+ sess.init_ccl(ccl, *devices)
+
+ array_1 = np.arange(36, dtype="float32")
+ array_2 = np.arange(48, dtype="float32")
+ d_src_1 = sess.empty((3, 3, 2), "float32")
+ d_dst_1 = sess.empty((3, 4, 3), "float32")
+ d_src_2 = sess.empty((2, 4, 3), "float32")
+ d_dst_2 = sess.empty((2, 6, 4), "float32")
+ d_src_1.debug_copy_from(0, array_1[:18])
+ d_src_1.debug_copy_from(1, array_1[18:])
+ d_src_2.debug_copy_from(2, array_2[:24])
+ d_src_2.debug_copy_from(3, array_2[24:])
+ sess.allgather(d_src_1, d_dst_1, in_group=True)
+ sess.allgather(d_src_2, d_dst_2, in_group=True)
+ np.testing.assert_equal(
+ d_dst_1.debug_get_from_remote(0).numpy(),
+ array_1.reshape(3, 4, 3),
+ )
+ np.testing.assert_equal(
+ d_dst_1.debug_get_from_remote(1).numpy(),
+ array_1.reshape(3, 4, 3),
+ )
+ np.testing.assert_equal(
+ d_dst_2.debug_get_from_remote(2).numpy(),
+ array_2.reshape(2, 6, 4),
+ )
+ np.testing.assert_equal(
+ d_dst_2.debug_get_from_remote(3).numpy(),
+ array_2.reshape(2, 6, 4),
+ )
+
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
@pytest.mark.parametrize("use_explicit_output", [True, False])
-def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output):
+def test_broadcast(session_kind, ccl, use_explicit_output):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
@@ -123,6 +196,29 @@ def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output):
np.testing.assert_equal(result, array)
+@pytest.mark.parametrize("session_kind", _all_session_kinds)
+@pytest.mark.parametrize("ccl", _ccl)
+def test_group_broadcast(session_kind, ccl):
+ devices = [0, 1, 2, 3]
+ sess = session_kind(num_workers=len(devices), num_groups=2)
+ sess.init_ccl(ccl, *devices)
+
+ array_1 = np.arange(12, dtype="float32").reshape(3, 4)
+ array_2 = np.multiply(array_1, -1)
+
+ src_array = sess.empty((3, 4), "float32", worker0_only=True, in_group=True)
+ src_array.debug_copy_from(0, array_1)
+ src_array.debug_copy_from(2, array_2)
+ dst_array = sess.empty((3, 4), "float32")
+ sess.broadcast_from_worker0(src_array, dst_array)
+
+ result_1 = dst_array.debug_get_from_remote(1).numpy()
+ np.testing.assert_equal(result_1, array_1)
+
+ result_3 = dst_array.debug_get_from_remote(3).numpy()
+ np.testing.assert_equal(result_3, array_2)
+
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
@pytest.mark.parametrize("use_explicit_output", [True, False])
@@ -156,6 +252,45 @@ def test_scatter(session_kind, ccl, use_explicit_output, capfd):
), "No warning messages should be generated from disco.Session.scatter_from_worker0"
[email protected]("session_kind", _all_session_kinds)
[email protected]("ccl", _ccl)
+def test_group_scatter(session_kind, ccl, capfd):
+ devices = [0, 1, 2, 3]
+ sess = session_kind(num_workers=len(devices), num_groups=2)
+ sess.init_ccl(ccl, *devices)
+
+ array_1 = np.arange(36, dtype="float32").reshape(2, 6, 3)
+ array_2 = np.multiply(array_1, -1)
+
+ d_src = sess.empty((2, 6, 3), "float32", worker0_only=True, in_group=True)
+ d_src.debug_copy_from(0, array_1)
+ d_src.debug_copy_from(2, array_2)
+ d_dst = sess.empty((6, 3), "float32")
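+ # Each group leader (worker 0 and worker 2) is expected to scatter its two (6, 3) slices to the workers of its own group only.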
+ sess.scatter_from_worker0(d_src, d_dst)
+
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(0).numpy(),
+ array_1[0, :, :],
+ )
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(1).numpy(),
+ array_1[1, :, :],
+ )
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(2).numpy(),
+ array_2[0, :, :],
+ )
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(3).numpy(),
+ array_2[1, :, :],
+ )
+
+ captured = capfd.readouterr()
+ assert (
+ not captured.err
+ ), "No warning messages should be generated from
disco.Session.scatter_from_worker0"
+
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_scatter_with_implicit_reshape(session_kind, ccl, capfd):
@@ -225,6 +360,37 @@ def test_gather(session_kind, ccl, capfd):
), "No warning messages should be generated from
disco.Session.gather_to_worker0"
[email protected]("session_kind", _all_session_kinds)
[email protected]("ccl", _ccl)
+def test_group_gather(session_kind, ccl, capfd):
+ devices = [0, 1, 2, 3]
+ sess = session_kind(num_workers=len(devices), num_groups=2)
+ sess.init_ccl(ccl, *devices)
+
+ array_1 = np.arange(36, dtype="float32")
+ array_2 = np.multiply(array_1, -1)
+ d_src = sess.empty((3, 3, 2), "float32")
+ d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True, in_group=True)
+ d_src.debug_copy_from(0, array_1[:18])
+ d_src.debug_copy_from(1, array_1[18:])
+ d_src.debug_copy_from(2, array_2[:18])
+ d_src.debug_copy_from(3, array_2[18:])
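+ # With a grouped destination, worker 0 should receive the shards of workers 0-1 and worker 2 the shards of workers 2-3.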
+ sess.gather_to_worker0(d_src, d_dst)
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(0).numpy(),
+ array_1.reshape(3, 4, 3),
+ )
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(2).numpy(),
+ array_2.reshape(3, 4, 3),
+ )
+
+ captured = capfd.readouterr()
+ assert (
+ not captured.err
+ ), "No warning messages should be generated from
disco.Session.gather_to_worker0"
+
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals
diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py
index 502cbe0b81..b4e2440857 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -22,6 +22,7 @@ import tempfile
import numpy as np
import tvm
+import tvm.testing
from tvm import dlight as dl
from tvm import relax as rx
from tvm._ffi import register_func
@@ -246,7 +247,7 @@ def test_load_shard_in_relax():
@R.function
def main(
loader: R.Object,
- ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128),
"float32"),):
+ ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128),
"float32")):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv0: R.Tensor((64, 64), "float32") = R.call_pure_packed(
diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py
index ef8ea2e70a..837b3a14f2 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -22,13 +22,14 @@ import numpy as np
import pytest
import tvm
+import tvm.testing
from tvm import relax as rx
from tvm.runtime import ShapeTuple, String
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
-from tvm.testing import disco as _
+from tvm.exec import disco_worker as _
def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
@@ -168,14 +169,14 @@ def test_vm_multi_func(session_kind):
@T.prim_func
def t1(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
for i, j in T.grid(16, 8):
- with T.block("transpose"):
+ with T.block("t1"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vj, vi]
@T.prim_func
def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")):
for i, j in T.grid(8, 16):
- with T.block("transpose"):
+ with T.block("t2"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vj, vi]
@@ -183,7 +184,7 @@ def test_vm_multi_func(session_kind):
def transpose_1(
A: R.Tensor((8, 16), dtype="float32")
) -> R.Tensor((16, 8), dtype="float32"):
- R.func_attr({"global_symbol": "main"})
+ R.func_attr({"global_symbol": "transpose_1"})
cls = TestMod
with R.dataflow():
B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32"))
@@ -194,7 +195,7 @@ def test_vm_multi_func(session_kind):
def transpose_2(
A: R.Tensor((16, 8), dtype="float32")
) -> R.Tensor((8, 16), dtype="float32"):
- R.func_attr({"global_symbol": "main"})
+ R.func_attr({"global_symbol": "transpose_2"})
cls = TestMod
with R.dataflow():
B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32"))
@@ -228,11 +229,4 @@ def test_num_workers(session_kind, num_workers):
if __name__ == "__main__":
- test_int(di.ProcessSession)
- test_float(di.ProcessSession)
- test_string(di.ProcessSession)
- test_string_obj(di.ProcessSession)
- test_shape_tuple(di.ProcessSession)
- test_ndarray(di.ProcessSession)
- test_vm_module(di.ProcessSession)
- test_vm_multi_func(di.ProcessSession)
+ tvm.testing.main()
diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py
index 3a76f535d7..6ee64a1815 100644
--- a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py
+++ b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py
@@ -220,7 +220,7 @@ def test_mlp():
out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"),
)
lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce(
- gv, op_type="sum"
+ gv, op_type="sum", in_group=False
)
return lv3
@@ -1559,7 +1559,7 @@ def test_llama_attention():
out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"),
)
lv43: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.ccl.allreduce(
- gv, op_type="sum"
+ gv, op_type="sum", in_group=False
)
lv44: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.dist.call_tir_local_view(
cls.add,
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py
index 63563ee3c9..9ea4d21d61 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -40,11 +40,11 @@ def test_allreduce():
class Expected:
@R.function
def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
- gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
- gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
- gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
- gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
- gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
return x
# fmt: on
@@ -66,8 +66,8 @@ def test_allgather():
class Expected:
@R.function
def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
- gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32"))
- gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32"))
+ gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32"))
+ gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32"))
return x
# fmt: on
@@ -88,7 +88,7 @@ def test_broadcast_from_zero():
class Expected:
@R.function
def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
- gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", x, out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", [x, False], out_sinfo=R.Tensor((10, 10), dtype="float32"))
return x
# fmt: on
@@ -134,7 +134,7 @@ def test_scatter_from_worker0():
cls = Expected
gv = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((10, 2, 5), dtype="float32"))
gv1 = R.call_tir(cls.transpose, (gv,), out_sinfo=R.Tensor((2, 10, 5), dtype="float32"))
- gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1,), out_sinfo=R.Tensor((10, 5), dtype="float32"))
+ gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1, False), out_sinfo=R.Tensor((10, 5), dtype="float32"))
return gv0
# fmt: on