This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bbc97c77fb [Disco] Group-wise operation (#17180)
bbc97c77fb is described below

commit bbc97c77fbd890361a8705c4450057c5c1bfd0db
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Jul 23 05:52:57 2024 -0700

    [Disco] Group-wise operation (#17180)
    
    This PR introduces the group attribute into Disco, so that group-wise
    allreduce and allgather are enabled.
---
 include/tvm/relax/attrs/ccl.h                      |  18 +++
 include/tvm/runtime/disco/builtin.h                |  15 +-
 include/tvm/runtime/disco/disco_worker.h           |   8 +-
 include/tvm/runtime/disco/session.h                |   8 +-
 python/tvm/exec/disco_worker.py                    |  15 +-
 python/tvm/relax/frontend/nn/op.py                 |  13 +-
 python/tvm/relax/op/ccl/ccl.py                     |  24 +--
 python/tvm/relax/transform/legalize_ops/ccl.py     |  10 +-
 python/tvm/runtime/disco/process_pool.py           |  10 +-
 python/tvm/runtime/disco/session.py                | 101 +++++++++----
 src/relax/op/ccl/ccl.cc                            |  22 ++-
 src/relax/op/ccl/ccl.h                             |   4 +-
 src/runtime/disco/builtin.cc                       |  34 +++--
 src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc      |   4 +-
 src/runtime/disco/cuda_ipc/custom_allreduce.cc     |   4 +-
 src/runtime/disco/disco_worker_thread.h            |   4 +-
 src/runtime/disco/loader.cc                        |   8 +-
 src/runtime/disco/nccl/nccl.cc                     | 102 ++++++++-----
 src/runtime/disco/nccl/nccl_context.h              |  13 +-
 src/runtime/disco/process_session.cc               |  21 ++-
 src/runtime/disco/threaded_session.cc              |  16 +-
 tests/python/disco/test_callback.py                |  11 +-
 tests/python/disco/test_ccl.py                     | 168 ++++++++++++++++++++-
 tests/python/disco/test_loader.py                  |   3 +-
 tests/python/disco/test_session.py                 |  20 +--
 ...ributed_transform_lower_global_to_local_view.py |   4 +-
 .../relax/test_transform_legalize_ops_ccl.py       |  18 +--
 27 files changed, 491 insertions(+), 187 deletions(-)

diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index 42cec88de6..de043f92be 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -32,14 +32,32 @@ namespace relax {
 /*! \brief Attributes used in allreduce operators */
 struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
   String op_type;
+  bool in_group;
 
   TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
     TVM_ATTR_FIELD(op_type).describe(
         "The type of reduction operation to be applied to the input data. Now 
only sum is "
         "supported.");
+    TVM_ATTR_FIELD(in_group).describe(
+        "Whether the reduction operation performs in group or globally or in 
group as default.");
   }
 };  // struct AllReduceAttrs
 
+/*! \brief Attributes used in allgather operators */
+struct AllGatherAttrs : public tvm::AttrsNode<AllGatherAttrs> {
+  int num_workers;
+  bool in_group;
+
+  TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") {
+    TVM_ATTR_FIELD(num_workers)
+        .describe(
+            "The number of workers, also the number of parts the given buffer 
should be chunked "
+            "into.");
+    TVM_ATTR_FIELD(in_group).describe(
+        "Whether the allgather operation performs in group or globally or in 
group as default.");
+  }
+};  // struct AllGatherAttrs
+
 /*! \brief Attributes used in scatter operators */
 struct ScatterCollectiveAttrs : public tvm::AttrsNode<ScatterCollectiveAttrs> {
   int num_workers;
diff --git a/include/tvm/runtime/disco/builtin.h 
b/include/tvm/runtime/disco/builtin.h
index cf9967dbfe..7d15e35fbd 100644
--- a/include/tvm/runtime/disco/builtin.h
+++ b/include/tvm/runtime/disco/builtin.h
@@ -75,35 +75,40 @@ TVM_DLL NDArray DiscoEmptyNDArray(ShapeTuple shape, 
DataType dtype, Device devic
  * \brief Perform an allreduce operation using the underlying communication 
library
  * \param send The array send to perform allreduce on
  * \param reduce_kind The kind of reduction operation (e.g. sum, avg, min, max)
+ * \param in_group Whether the allreduce operation performs globally or in 
group as default.
  * \param recv The array receives the outcome of allreduce
  */
-TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
+TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, 
NDArray recv);
 /*!
  * \brief Perform an allgather operation using the underlying communication 
library
  * \param send The array send to perform allgather on
+ * \param in_group Whether the allgather operation performs globally or in 
group as default.
  * \param recv The array receives the outcome of allgather
  */
-TVM_DLL void AllGather(NDArray send, NDArray recv);
+TVM_DLL void AllGather(NDArray send, bool in_group, NDArray recv);
 /*!
  * \brief Perform a broadcast operation from worker-0
  * \param send The buffer to be broadcasted
+ * \param in_group Whether the broadcast operation performs globally or in 
group as default.
  * \param recv The buffer receives the broadcasted array
  */
-TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv);
+TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv);
 /*!
  * \brief Perform a scatter operation from worker-0, chunking the given buffer 
into equal parts.
  * \param send For worker-0, it must be provided, and otherwise, the buffer 
must be None.
  * The buffer will be divided into equal parts and sent to each worker 
accordingly.
+ * \param in_group Whether the scatter operation performs globally or in group 
as default.
  * \param recv The receiving buffer, which must not be None.
  */
-TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
+TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, bool in_group, NDArray 
recv);
 /*!
  * \brief Perform a gather operation to worker-0.
  * \param send The sending buffer, which must not be None.
+ * \param in_group Whether the gather operation performs globally or in group 
as default.
  * \param recv For worker-0, it must be provided, and otherwise, the buffer 
must be None. The
  * receiving buffer will be divided into equal parts and receive from each 
worker accordingly.
  */
-TVM_DLL void GatherToWorker0(NDArray send, Optional<NDArray> recv);
+TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> 
recv);
 /*!
  * \brief Receive a buffer from worker-0. No-op if the current worker is 
worker-0.
  * \param buffer The buffer to be received
diff --git a/include/tvm/runtime/disco/disco_worker.h 
b/include/tvm/runtime/disco/disco_worker.h
index 14f8f23807..301b5b8d62 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -44,14 +44,16 @@ class DiscoWorker {
    * \brief Construct a worker.
    * \param worker_id The id of the worker.
    * \param num_workers The number of the workers.
+   * \param num_groups The number of the worker groups.
    * \param worker_zero_data The data shared between worker-0 and the 
controler. It's a nullptr if
    * the worker is not worker-0.
    * \param channel The communication channel between the worker and the 
controler.
    */
-  explicit DiscoWorker(int worker_id, int num_workers, WorkerZeroData* 
worker_zero_data,
-                       DiscoChannel* channel)
+  explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
+                       WorkerZeroData* worker_zero_data, DiscoChannel* channel)
       : worker_id(worker_id),
         num_workers(num_workers),
+        num_groups(num_groups),
         default_device(Device{DLDeviceType::kDLCPU, 0}),
         worker_zero_data(worker_zero_data),
         channel(channel),
@@ -68,6 +70,8 @@ class DiscoWorker {
   int worker_id;
   /*! \brief Total number of workers */
   int num_workers;
+  /*! \brief Total number of worker groups */
+  int num_groups;
   /*! \brief The default device to allocate data if not specified */
   Device default_device;
   /*! \brief The name of the underlying collective communication library. */
diff --git a/include/tvm/runtime/disco/session.h 
b/include/tvm/runtime/disco/session.h
index 71fcce75b2..97fa79096d 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -264,11 +264,13 @@ class Session : public ObjectRef {
   /*!
    * \brief Create a session backed by a thread pool of workers
    * \param num_workers The number of workers.
+   * \param num_groups The number of worker groups.
    */
-  TVM_DLL static Session ThreadedSession(int num_workers);
+  TVM_DLL static Session ThreadedSession(int num_workers, int num_groups);
   /*!
    * \brief Create a session backed by pipe-based multiprocessing
    * \param num_workers The number of workers.
+   * \param num_groups The number of worker groups.
    * \param process_pool_creator The name of a global function that takes 
`num_workers` as an input,
    * and returns a PackedFunc, which takes an integer `worker_id` as the input 
and returns None.
    * When `worker-id` is 0, it shuts down the process pool; Otherwise, it 
retursn a tuple
@@ -277,8 +279,8 @@ class Session : public ObjectRef {
    * \note Worker-0 is always co-located with the controler as a separate 
thread, and therefore
    * worker-0 does not exist in the process pool.
    */
-  TVM_DLL static Session ProcessSession(int num_workers, String 
process_pool_creator,
-                                        String entrypoint);
+  TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
+                                        String process_pool_creator, String 
entrypoint);
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, 
SessionObj);
 };
 
diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
index 76ce0ff993..b1f1554b56 100644
--- a/python/tvm/exec/disco_worker.py
+++ b/python/tvm/exec/disco_worker.py
@@ -99,22 +99,23 @@ def _make_callback(device: tvm.runtime.Device) -> 
Callable[[str, int], NDArray]:
 
 def main():
     """Main worker function"""
-    if len(sys.argv) != 5:
-        print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>")
+    if len(sys.argv) != 6:
+        print("Usage: <worker_id> <num_workers> <num_groups> <read_fd> 
<write_fd>")
         return
     worker_id = int(sys.argv[1])
     num_workers = int(sys.argv[2])
+    num_groups = int(sys.argv[3])
     if sys.platform == "win32":
         import msvcrt  # pylint: disable=import-outside-toplevel,import-error
 
-        reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY)
-        writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
+        reader = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
+        writer = msvcrt.open_osfhandle(int(sys.argv[5]), os.O_BINARY)
     else:
-        reader = int(sys.argv[3])
-        writer = int(sys.argv[4])
+        reader = int(sys.argv[4])
+        writer = int(sys.argv[5])
 
     worker_func = get_global_func("runtime.disco.WorkerProcess")
-    worker_func(worker_id, num_workers, reader, writer)
+    worker_func(worker_id, num_workers, num_groups, reader, writer)
 
 
 if __name__ == "__main__":
diff --git a/python/tvm/relax/frontend/nn/op.py 
b/python/tvm/relax/frontend/nn/op.py
index ec072f663c..e1ba4483c7 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1671,16 +1671,21 @@ def interpolate(
     )
 
 
-def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"):
+def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, 
name="ccl_allreduce"):
     """CCL Allreduce operator
 
     Parameters
     ----------
-    x : Tensor
+    x : Tensor
       The input tensor.
-    op_type: str
+
+    op_type : str
       The type of reduction operation to be applied to the input data.
       Now "sum", "prod", "min", "max" and "avg" are supported.
+
+    in_group : bool
+      Whether the reduction operation performs globally or in group as default.
+
     name : str
         Name hint for this operation.
 
@@ -1689,7 +1694,7 @@ def ccl_allreduce(x: Tensor, op_type: str = "sum", 
name="ccl_allreduce"):
     result : Tensor
       The result tensor of allreduce.
     """
-    return wrap_nested(_op.ccl.allreduce(x._expr, op_type), name)
+    return wrap_nested(_op.ccl.allreduce(x._expr, op_type, in_group), name)
 
 
 def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"):
diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 21c7946120..982c048021 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -15,25 +15,26 @@
 # specific language governing permissions and limitations
 # under the License.
 """Relax Collective Communications Library (CCL) operators"""
-from typing import Union
-from tvm.relax import PrimValue
 
 from . import _ffi_api
 from ...expr import Expr
-from ....ir import PrimExpr
 
 
-def allreduce(x, op_type: str = "sum"):  # pylint: disable=invalid-name
+def allreduce(x, op_type: str = "sum", in_group: bool = True):  # pylint: 
disable=invalid-name
     """Allreduce operator
 
     Parameters
     ----------
     x : relax.Expr
       The input tensor.
-    op_type: str
+
+    op_type : str
       The type of reduction operation to be applied to the input data.
       Now "sum", "prod", "min", "max" and "avg" are supported.
 
+    in_group : bool
+      Whether the reduction operation performs globally or in group as default.
+
     Returns
     -------
     result : relax.Expr
@@ -44,10 +45,10 @@ def allreduce(x, op_type: str = "sum"):  # pylint: 
disable=invalid-name
         "Allreduce only supports limited reduction operations, "
         f"including {supported_op_types}, but got {op_type}."
     )
-    return _ffi_api.allreduce(x, op_type)  # type: ignore # pylint: 
disable=no-member
+    return _ffi_api.allreduce(x, op_type, in_group)  # type: ignore # pylint: 
disable=no-member
 
 
-def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]):  # pylint: 
disable=invalid-name
+def allgather(x, num_workers: int, in_group: bool = True):  # pylint: 
disable=invalid-name
     """AllGather operator
 
     Parameters
@@ -55,17 +56,18 @@ def allgather(x, num_workers: Union[int, PrimExpr, 
PrimValue]):  # pylint: disab
     x : relax.Expr
       The input tensor.
 
-    num_worker : Union[int, PrimExpr, PrimValue]
+    num_workers : int
       The number of workers to gather data from.
 
+    in_group : bool
+      Whether the gather operation performs globally or in group as default.
+
     Returns
     -------
     result : relax.Expr
       The result of allgather.
     """
-    if not isinstance(num_workers, PrimValue):
-        num_workers = PrimValue(num_workers)
-    return _ffi_api.allgather(x, num_workers)  # type: ignore # pylint: 
disable=no-member
+    return _ffi_api.allgather(x, num_workers, in_group)  # type: ignore # 
pylint: disable=no-member
 
 
 def broadcast_from_worker0(x: Expr) -> Expr:
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py 
b/python/tvm/relax/transform/legalize_ops/ccl.py
index ae0be3c228..364dee750e 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -41,7 +41,7 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
         )
     return call_dps_packed(
         "runtime.disco.allreduce",
-        [call.args[0], ShapeExpr([op_type_map[op_type_str]])],
+        [call.args[0], ShapeExpr([op_type_map[op_type_str]]), 
call.attrs.in_group],
         out_sinfo=call.args[0].struct_info,
     )
 
@@ -57,12 +57,12 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr:
     arg_shape = arg_sinfo.shape.struct_info
     for i, shape_value in enumerate(arg_shape.values):
         if i == 0:
-            output_shape.append(shape_value * call.args[1].value)
+            output_shape.append(shape_value * call.attrs.num_workers)
         else:
             output_shape.append(shape_value)
     return call_dps_packed(
         "runtime.disco.allgather",
-        call.args[0],
+        [call.args[0], call.attrs.in_group],
         out_sinfo=TensorStructInfo(
             shape=output_shape,
             dtype=arg_sinfo.dtype,
@@ -75,7 +75,7 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr:
 def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
     return call_dps_packed(
         "runtime.disco.broadcast_from_worker0",
-        call.args[0],
+        [call.args[0], False],
         out_sinfo=call.args[0].struct_info,
     )
 
@@ -116,7 +116,7 @@ def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> 
Expr:
     output_shape = output_shape[1:]
     return call_dps_packed(
         "runtime.disco.scatter_from_worker0",
-        transpose_var,
+        [transpose_var, False],
         out_sinfo=TensorStructInfo(
             shape=output_shape,
             dtype=call.args[0].struct_info.dtype,
diff --git a/python/tvm/runtime/disco/process_pool.py 
b/python/tvm/runtime/disco/process_pool.py
index 1ad8659d60..95969e038e 100644
--- a/python/tvm/runtime/disco/process_pool.py
+++ b/python/tvm/runtime/disco/process_pool.py
@@ -38,6 +38,9 @@ class DiscoPopenWorker:
     num_workers : int
         The total number of workers.
 
+    num_groups : int
+        The total number of worker groups.
+
     stdout: Union[None, int, IO[Any]]
         The standard output streams handler specified for the popen process.
 
@@ -49,12 +52,14 @@ class DiscoPopenWorker:
         self,
         worker_id: int,
         num_workers: int,
+        num_groups: int,
         entrypoint: str = "tvm.exec.disco_worker",
         stdout=None,
         stderr=None,
     ):
         self.worker_id = worker_id
         self.num_workers = num_workers
+        self.num_groups = num_groups
         self.entrypoint = entrypoint
         self._proc = None
         self._stdout = stdout
@@ -118,6 +123,7 @@ class DiscoPopenWorker:
             self.entrypoint,
             str(self.worker_id),
             str(self.num_workers),
+            str(self.num_groups),
         ]
         if sys.platform == "win32":
             import msvcrt  # pylint: 
disable=import-error,import-outside-toplevel
@@ -172,9 +178,9 @@ def _kill_child_processes(pid):
 
 
 @register_func("runtime.disco.create_process_pool")
-def _create_process_pool(num_workers: int, entrypoint: str):
+def _create_process_pool(num_workers: int, num_groups: int, entrypoint: str):
     """Create a process pool where the workers' are [1, num_workers)."""
-    pool = [DiscoPopenWorker(i, num_workers, entrypoint) for i in range(1, 
num_workers)]
+    pool = [DiscoPopenWorker(i, num_workers, num_groups, entrypoint) for i in 
range(1, num_workers)]
 
     def result_func(worker_id: int):
         nonlocal pool
diff --git a/python/tvm/runtime/disco/session.py 
b/python/tvm/runtime/disco/session.py
index ddde1bc1f3..38c4f2a235 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -66,6 +66,7 @@ class DRef(Object):
         ----------
         worker_id : int
             The id of the worker to be copied to.
+
         value : Union[numpy.ndarray, NDArray]
             The value to be copied.
         """
@@ -121,6 +122,7 @@ class Session(Object):
         dtype: str,
         device: Optional[Device] = None,
         worker0_only: bool = False,
+        in_group: bool = True,
     ) -> DRef:
         """Create an empty NDArray on all workers and attach them to a DRef.
 
@@ -139,6 +141,11 @@ class Session(Object):
             If False (default), allocate an array on each worker.  If
             True, only allocate an array on worker0.
 
+        in_group: bool
+            Takes effect only when `worker0_only` is True. If True (default),
+            allocate an array on the first worker of each group. If
+            False, only allocate an array on worker0 globally.
+
         Returns
         -------
         array : DRef
@@ -148,7 +155,7 @@ class Session(Object):
         if device is None:
             device = Device(device_type=0, device_id=0)
         func = self._get_cached_method("runtime.disco.empty")
-        return func(ShapeTuple(shape), dtype, device, worker0_only)
+        return func(ShapeTuple(shape), dtype, device, worker0_only, in_group)
 
     def shutdown(self):
         """Shut down the Disco session"""
@@ -244,6 +251,7 @@ class Session(Object):
         ----------
         host_array : numpy.ndarray
             The array to be copied to worker-0.
+
         remote_array : NDArray
             The NDArray on worker-0.
         """
@@ -255,11 +263,9 @@ class Session(Object):
         Parameters
         ----------
         host_array : NDArray
-
             The array to be copied to worker-0.
 
         remote_array : Optiona[DRef]
-
             The destination NDArray on worker-0.
 
         Returns
@@ -289,6 +295,7 @@ class Session(Object):
         ----------
         path : str
             The path to the VM module file.
+
         device : Optional[Device] = None
             The device to load the VM module to. Default to the default device 
of each worker.
 
@@ -312,6 +319,7 @@ class Session(Object):
             - nccl
             - rccl
             - mpi
+
         *device_ids : int
             The device IDs to be used by the underlying communication library.
         """
@@ -319,20 +327,23 @@ class Session(Object):
         _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids))  # type: 
ignore # pylint: disable=no-member
         self._clear_ipc_memory_pool()
 
-    def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = 
None) -> DRef:
+    def broadcast(
+        self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, 
in_group: bool = True
+    ) -> DRef:
         """Broadcast an array to all workers
 
         Parameters
         ----------
         src: Union[np.ndarray, NDArray]
-
             The array to be broadcasted.
 
         dst: Optional[DRef]
-
             The output array.  If None, an array matching the shape
             and dtype of `src` will be allocated on each worker.
 
+        in_group: bool
+            Whether the broadcast operation performs globally or in group as 
default.
+
         Returns
         -------
         output_array: DRef
@@ -349,38 +360,48 @@ class Session(Object):
             dst = self.empty(src.shape, src.dtype)
 
         src_dref = self.copy_to_worker_0(src)
-        self.broadcast_from_worker0(src_dref, dst)
+        self.broadcast_from_worker0(src_dref, dst, in_group)
 
         return dst
 
-    def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
+    def broadcast_from_worker0(self, src: DRef, dst: DRef, in_group: bool = 
True) -> DRef:
         """Broadcast an array from worker-0 to all other workers.
 
         Parameters
         ----------
-        array : DRef
-            The array to be broadcasted in-place
+        src: DRef
+            The array on worker-0 to be broadcasted.
+
+        dst: DRef
+            The output array on each worker that receives the
+            broadcasted data. It is not allocated by this method.
+
+        in_group: bool
+            Whether the broadcast operation performs globally or in group as 
default.
         """
         func = self._get_cached_method("runtime.disco.broadcast_from_worker0")
-        func(src, dst)
+        func(src, in_group, dst)
 
-    def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = 
None) -> DRef:
+    def scatter(
+        self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, 
in_group: bool = True
+    ) -> DRef:
         """Scatter an array across all workers
 
         Parameters
         ----------
         src: Union[np.ndarray, NDArray]
-
             The array to be scattered.  The first dimension of this
             array, `src.shape[0]`, must be equal to the number of
             workers.
 
         dst: Optional[DRef]
-
             The output array.  If None, an array with compatible shape
             and the same dtype as `src` will be allocated on each
             worker.
 
+        in_group: bool
+            Whether the scatter operation performs globally or in group as 
default.
+
         Returns
         -------
         output_array: DRef
@@ -399,41 +420,54 @@ class Session(Object):
             dst = self.empty(src.shape[1:], src.dtype)
 
         src_dref = self.copy_to_worker_0(src)
-        self.scatter_from_worker0(src_dref, dst)
+        self.scatter_from_worker0(src_dref, dst, in_group)
 
         return dst
 
-    def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None:
+    def scatter_from_worker0(self, from_array: DRef, to_array: DRef, in_group: 
bool = True) -> None:
         """Scatter an array from worker-0 to all other workers.
 
         Parameters
         ----------
-        from_array : DRef
-            The array to be scattered from.
-        to_array : DRef
-            The array to be scattered to.
+        from_array : DRef
+            The array on worker-0 to be scattered from.  The first
+            dimension of this array, `from_array.shape[0]`, must be
+            equal to the number of workers.
+
+        to_array : DRef
+            The array on each worker that receives its chunk of
+            `from_array`.  It is not allocated by this method and
+            must be provided by the caller.
+
+        in_group: bool
+            Whether the scatter operation performs globally or in group as 
default.
         """
         func = self._get_cached_method("runtime.disco.scatter_from_worker0")
-        func(from_array, to_array)
+        func(from_array, in_group, to_array)
 
-    def gather_to_worker0(self, from_array: DRef, to_array: DRef) -> None:
+    def gather_to_worker0(self, from_array: DRef, to_array: DRef, in_group: 
bool = True) -> None:
         """Gather an array from all other workers to worker-0.
 
         Parameters
         ----------
         from_array : DRef
             The array to be gathered from.
+
         to_array : DRef
             The array to be gathered to.
+
+        in_group: bool
+            Whether the gather operation performs globally or in group as 
default.
         """
         func = self._get_cached_method("runtime.disco.gather_to_worker0")
-        func(from_array, to_array)
+        func(from_array, in_group, to_array)
 
     def allreduce(
         self,
         src: DRef,
         dst: DRef,
         op: str = "sum",  # pylint: disable=invalid-name
+        in_group: bool = True,
     ) -> DRef:
         """Perform an allreduce operation on an array.
 
@@ -441,6 +475,7 @@ class Session(Object):
         ----------
         array : DRef
             The array to be reduced.
+
         op : str = "sum"
             The reduce operation to be performed. Available options are:
             - "sum"
@@ -448,17 +483,21 @@ class Session(Object):
             - "min"
             - "max"
             - "avg"
+
+        in_group : bool
+            Whether the reduce operation performs globally or in group as 
default.
         """
         if op not in REDUCE_OPS:
             raise ValueError(f"Unsupported reduce op: {op}. Available ops are: 
{REDUCE_OPS.keys()}")
         op = ShapeTuple([REDUCE_OPS[op]])
         func = self._get_cached_method("runtime.disco.allreduce")
-        func(src, op, dst)
+        func(src, op, in_group, dst)
 
     def allgather(
         self,
         src: DRef,
         dst: DRef,
+        in_group: bool = True,
     ) -> DRef:
         """Perform an allgather operation on an array.
 
@@ -466,11 +505,15 @@ class Session(Object):
         ----------
         src : DRef
             The array to be gathered from.
+
         dst : DRef
             The array to be gathered to.
+
+        in_group : bool
+            Whether the allgather operation performs globally or in group as 
default.
         """
         func = self._get_cached_method("runtime.disco.allgather")
-        func(src, dst)
+        func(src, in_group, dst)
 
     def _clear_ipc_memory_pool(self):
         # Clear the IPC memory allocator when the allocator exists.
@@ -483,11 +526,12 @@ class Session(Object):
 class ThreadedSession(Session):
     """A Disco session backed by multi-threading."""
 
-    def __init__(self, num_workers: int) -> None:
+    def __init__(self, num_workers: int, num_groups: int = 1) -> None:
         """Create a disco session backed by multiple threads in the same 
process."""
         self.__init_handle_by_constructor__(
             _ffi_api.SessionThreaded,  # type: ignore # pylint: 
disable=no-member
             num_workers,
+            num_groups,
         )
 
 
@@ -495,10 +539,13 @@ class ThreadedSession(Session):
 class ProcessSession(Session):
     """A Disco session backed by pipe-based multi-processing."""
 
-    def __init__(self, num_workers: int, entrypoint: str = 
"tvm.exec.disco_worker") -> None:
+    def __init__(
+        self, num_workers: int, num_groups: int = 1, entrypoint: str = 
"tvm.exec.disco_worker"
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.SessionProcess,  # type: ignore # pylint: 
disable=no-member
             num_workers,
+            num_groups,
             "runtime.disco.create_process_pool",
             entrypoint,
         )
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index c0fe6f4d88..092727cb51 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -27,9 +27,10 @@ namespace relax {
 /* relax.ccl.allreduce */
 TVM_REGISTER_NODE_TYPE(AllReduceAttrs);
 
-Expr allreduce(Expr x, String op_type) {
+Expr allreduce(Expr x, String op_type, bool in_group) {
   ObjectPtr<AllReduceAttrs> attrs = make_object<AllReduceAttrs>();
   attrs->op_type = std::move(op_type);
+  attrs->in_group = std::move(in_group);
 
   static const Op& op = Op::Get("relax.ccl.allreduce");
   return Call(op, {std::move(x)}, Attrs{attrs}, {});
@@ -51,19 +52,24 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.ccl.allgather */
-Expr allgather(Expr x, Expr num_workers) {
+TVM_REGISTER_NODE_TYPE(AllGatherAttrs);
+
+Expr allgather(Expr x, int num_workers, bool in_group) {
+  ObjectPtr<AllGatherAttrs> attrs = make_object<AllGatherAttrs>();
+  attrs->num_workers = std::move(num_workers);
+  attrs->in_group = std::move(in_group);
+
   static const Op& op = Op::Get("relax.ccl.allgather");
-  return Call(op, {std::move(x), std::move(num_workers)});
+  return Call(op, {std::move(x)}, Attrs{attrs}, {});
 }
 
 TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather);
 
 StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) 
{
-  CHECK_EQ(call->args.size(), 2);
-  auto input_sinfo = Downcast<TensorStructInfo>(call->args[0]->struct_info_);
-  auto num_workers_sinfo = 
Downcast<PrimStructInfo>(call->args[1]->struct_info_);
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
 
-  auto num_workers = num_workers_sinfo->value;
+  const auto* attrs = call->attrs.as<AllGatherAttrs>();
+  int num_workers = attrs->num_workers;
 
   DataType output_dtype = input_sinfo->dtype;
   auto input_shape = input_sinfo->GetShape();
@@ -71,7 +77,7 @@ StructInfo InferStructInfoAllGather(const Call& call, const 
BlockBuilder& ctx) {
     return input_sinfo;
   }
   Array<PrimExpr> output_shape = input_shape.value();
-  output_shape.Set(0, floor(output_shape[0] * num_workers.value()));
+  output_shape.Set(0, floor(output_shape[0] * num_workers));
   return TensorStructInfo(ShapeExpr(output_shape), output_dtype, 
input_sinfo->vdevice);
 }
 
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 3e7f0220c9..82ea393567 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -33,10 +33,10 @@ namespace tvm {
 namespace relax {
 
 /*! \brief AllReduce. */
-Expr allreduce(Expr data, String op_type);
+Expr allreduce(Expr data, String op_type, bool in_group);
 
 /*! \brief AllGather. */
-Expr allgather(Expr data, Expr num_workers);
+Expr allgather(Expr data, int num_workers, bool in_group);
 
 /*! \brief Broadcast data from worker-0 to all other workers. */
 Expr broadcast_from_worker0(Expr data);
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 26d1c22ee9..0cb2ee6f5d 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -79,22 +79,24 @@ const PackedFunc& GetCCLFunc(const char* name) {
   return *pf;
 }
 
-void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
-  GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), recv);
+void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray 
recv) {
+  GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), in_group, recv);
 }
 
-void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, 
recv); }
+void AllGather(NDArray send, bool in_group, NDArray recv) {
+  GetCCLFunc("allgather")(send, in_group, recv);
+}
 
-TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv) {
-  GetCCLFunc("broadcast_from_worker0")(send, recv);
+TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv) {
+  GetCCLFunc("broadcast_from_worker0")(send, in_group, recv);
 }
 
-TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
-  GetCCLFunc("scatter_from_worker0")(send, recv);
+TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, bool in_group, NDArray 
recv) {
+  GetCCLFunc("scatter_from_worker0")(send, in_group, recv);
 }
 
-void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
-  GetCCLFunc("gather_to_worker0")(send, recv);
+void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv) {
+  GetCCLFunc("gather_to_worker0")(send, in_group, recv);
 }
 
 void RecvFromWorker0(NDArray buffer) { 
GetCCLFunc("recv_from_worker0")(buffer); }
@@ -110,9 +112,13 @@ void SyncWorker() {
 
TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
 
 TVM_REGISTER_GLOBAL("runtime.disco.empty")
-    .set_body_typed([](ShapeTuple shape, DataType dtype, Device device,
-                       bool worker0_only) -> Optional<NDArray> {
-      if (worker0_only && WorkerId()) {
+    .set_body_typed([](ShapeTuple shape, DataType dtype, Device device, bool 
worker0_only,
+                       bool in_group) -> Optional<NDArray> {
+      int worker_id = WorkerId();
+      int group_size =
+          DiscoWorker::ThreadLocal()->num_workers / 
DiscoWorker::ThreadLocal()->num_groups;
+      bool is_worker0 = (worker_id == 0 && !in_group) || (in_group && 
worker_id % group_size == 0);
+      if (worker0_only && !is_worker0) {
         return NullOpt;
       } else {
         return DiscoEmptyNDArray(shape, dtype, device);
@@ -120,10 +126,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.empty")
     });
 
 TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
-    .set_body_typed([](NDArray send, ShapeTuple reduce_kind, NDArray recv) {
+    .set_body_typed([](NDArray send, ShapeTuple reduce_kind, bool in_group, 
NDArray recv) {
       int kind = IntegerFromShapeTuple(reduce_kind);
       CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << 
kind;
-      AllReduce(send, static_cast<ReduceKind>(kind), recv);
+      AllReduce(send, static_cast<ReduceKind>(kind), in_group, recv);
     });
 TVM_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather);
 
TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0);
diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc 
b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
index fec5abec86..490217d62c 100644
--- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
+++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
@@ -47,8 +47,8 @@ std::vector<cudaIpcMemHandle_t> 
AllGatherIPCHandles(nccl::CCLThreadLocalContext*
   CUDA_CALL(cudaMalloc(&d_src, CUDA_IPC_HANDLE_SIZE));
   CUDA_CALL(cudaMalloc(&d_dst, CUDA_IPC_HANDLE_SIZE * 
ctx->worker->num_workers));
   CUDA_CALL(cudaMemcpy(d_src, &local_handle, CUDA_IPC_HANDLE_SIZE, 
cudaMemcpyHostToDevice));
-  NCCL_CALL(
-      ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->comm, 
/*stream=*/nullptr));
+  NCCL_CALL(ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, 
ctx->global_comm,
+                          /*stream=*/nullptr));
   std::vector<char> serial_handles(CUDA_IPC_HANDLE_SIZE * 
ctx->worker->num_workers, 0);
   CUDA_CALL(cudaMemcpy(serial_handles.data(), d_dst,
                        CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, 
cudaMemcpyDefault));
diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc 
b/src/runtime/disco/cuda_ipc/custom_allreduce.cc
index 98fd777b83..d969005f94 100644
--- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc
+++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc
@@ -65,6 +65,8 @@ inline bool CanApplyTwoShotAllReduce(int64_t num_elements, 
DLDataType dtype, int
 void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) {
   int64_t num_elements = TensorSize(send);
   nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get();
+  CHECK_EQ(ctx->worker->num_groups, 1)
+      << "Custom AllReduce for multiple group is not yet implemented.";
 
   tensorrt_llm::AllReduceStrategyType strategy_ =
       static_cast<tensorrt_llm::AllReduceStrategyType>(strategy);
@@ -79,7 +81,7 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* 
recv) {
     deviceStream_t stream = ctx->GetDefaultStream();
     NCCL_CALL(ncclAllReduce(send->data, recv->data, num_elements,
                             
/*datatype=*/nccl::AsNCCLDataType(DataType(send->dtype)),
-                            /*op=*/ncclSum, ctx->comm, stream));
+                            /*op=*/ncclSum, ctx->global_comm, stream));
     return;
   }
 
diff --git a/src/runtime/disco/disco_worker_thread.h 
b/src/runtime/disco/disco_worker_thread.h
index 67742cdd04..8d6b44396f 100644
--- a/src/runtime/disco/disco_worker_thread.h
+++ b/src/runtime/disco/disco_worker_thread.h
@@ -47,12 +47,14 @@ class DiscoWorkerThread {
    * \brief Construct a worker thread.
    * \param worker_id The id of the worker.
    * \param num_workers The total number of workers.
+   * \param num_groups The total number of worker groups.
    * \param worker_zero_data_ The data shared between worker-0 and the 
controler. It's a nullptr if
    * the worker is not worker-0.
    * \note This method is implemented in threaded worker, because it depends 
on creation of a
    * sub-class of DiscoChannel, DiscoThreadChannel, which is hidden from the 
public interface.
    */
-  explicit DiscoWorkerThread(int worker_id, int num_workers, WorkerZeroData* 
worker_zero_data_);
+  explicit DiscoWorkerThread(int worker_id, int num_workers, int num_groups,
+                             WorkerZeroData* worker_zero_data_);
 
   /*! \brief Move constructor. */
   explicit DiscoWorkerThread(DiscoWorkerThread&& other)
diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index 7a5d978946..efe42539cb 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -326,19 +326,19 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
       for (const ShardInfo::ShardFunc& shard_func : 
param_info.shard_info.funcs) {
         w = this->ApplyShardFunc(shard_func, w);
       }
-      ScatterFromWorker0(w, recv);
+      ScatterFromWorker0(w, /*in_group=*/false, recv);
     } else {
-      ScatterFromWorker0(NullOpt, recv);
+      ScatterFromWorker0(NullOpt, /*in_group=*/false, recv);
     }
     return recv;
   } else {
     if (worker_id == 0) {
       NDArray w = LoadDirect(weight_index);
-      BroadcastFromWorker0(w, w);
+      BroadcastFromWorker0(w, /*in_group=*/false, w);
       return w;
     } else {
       NDArray w = NDArray::Empty(param->shape, param->dtype, device);
-      BroadcastFromWorker0(w, w);
+      BroadcastFromWorker0(w, /*in_group=*/false, w);
       return w;
     }
   }
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index bba42ed3bd..2d2c528b52 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -72,9 +72,12 @@ void InitCCLPerWorker(IntTuple device_ids, std::string 
unique_id_bytes) {
       << "ValueError: The length of unique_id must be " << 
NCCL_UNIQUE_ID_BYTES << ", but got "
       << unique_id_bytes.size() << ".";
 
-  CHECK(!ctx->comm) << "Cannot initialize CCL, "
-                    << "the previous thread-global comm still exists, "
-                    << "and has not been destructed";
+  CHECK(!ctx->global_comm) << "Cannot initialize CCL, "
+                           << "the previous thread-global comm still exists, "
+                           << "and has not been destructed";
+  CHECK(!ctx->group_comm) << "Cannot initialize CCL, "
+                          << "the previous thread-group comm still exists, "
+                          << "and has not been destructed";
   CHECK(!ctx->default_stream) << "Cannot initialize CCL, "
                               << "the previous thread-global stream still 
exists, "
                               << "and has not been destructed";
@@ -96,34 +99,41 @@ void InitCCLPerWorker(IntTuple device_ids, std::string 
unique_id_bytes) {
   // Initialize the communicator
   ncclUniqueId id;
   std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES);
-  NCCL_CALL(ncclCommInitRank(&ctx->comm, worker->num_workers, id, 
worker->worker_id));
+  int group_size = worker->num_workers / worker->num_groups;
+  NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, 
worker->worker_id));
+  NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size,
+                          worker->worker_id % group_size, &ctx->group_comm, 
NULL));
 }
 
-void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
+void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray 
recv) {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
   deviceStream_t stream = ctx->GetDefaultStream();
   NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
-                          /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
+                          /*op=*/AsNCCLRedOp(reduce_kind),
+                          in_group ? ctx->group_comm : ctx->global_comm, 
stream));
 }
 
-void AllGather(NDArray send, NDArray recv) {
+void AllGather(NDArray send, bool in_group, NDArray recv) {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
   deviceStream_t stream = ctx->GetDefaultStream();
   NCCL_CALL(ncclAllGather(send->data, recv->data, numel,
-                          /*datatype=*/AsNCCLDataType(DataType(send->dtype)), 
ctx->comm, stream));
+                          /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
+                          in_group ? ctx->group_comm : ctx->global_comm, 
stream));
 }
 
-void BroadcastFromWorker0(Optional<NDArray> send, NDArray recv) {
+void BroadcastFromWorker0(Optional<NDArray> send, bool in_group, NDArray recv) 
{
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+  int worker_id = ctx->worker->worker_id;
+  int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+  bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % 
group_size == 0);
 
   const void* send_data = [&]() -> const void* {
-    int worker_id = ctx->worker->worker_id;
-    if (worker_id == 0) {
+    if (is_sender) {
       CHECK(send.defined());
       CHECK(send.value().Shape()->Product() == recv.Shape()->Product());
       return send.value()->data;
@@ -136,25 +146,28 @@ void BroadcastFromWorker0(Optional<NDArray> send, NDArray 
recv) {
   deviceStream_t stream = ctx->GetDefaultStream();
   NCCL_CALL(ncclBroadcast(send_data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(recv->dtype)),
-                          /*root=*/0, ctx->comm, stream));
+                          /*root=*/0, in_group ? ctx->group_comm : 
ctx->global_comm, stream));
 }
 
-void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
+void ScatterFromWorker0(Optional<NDArray> send, bool in_group, NDArray recv) {
   CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None";
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   int worker_id = ctx->worker->worker_id;
   int num_workers = ctx->worker->num_workers;
+  int group_size = num_workers / ctx->worker->num_groups;
+  bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % 
group_size == 0);
+  int num_receiver = in_group ? group_size : num_workers;
   deviceStream_t stream = ctx->GetDefaultStream();
-  if (worker_id == 0) {
+  if (is_sender) {
     CHECK(send.defined()) << "ValueError: buffer `send` must be provided when 
worker_id == 0.";
     NDArray buffer = send.value();
     int64_t numel = buffer.Shape()->Product();
-    CHECK_EQ(numel % num_workers, 0) << "ValueError: Scattering evenly 
requires that the number "
-                                        "of elements in the buffer to be "
-                                        "divisible by the number of workers, 
but got numel = "
-                                     << numel << " and " << num_workers << " 
workers.";
+    CHECK_EQ(numel % num_receiver, 0) << "ValueError: Scattering evenly 
requires that the number "
+                                         "of elements in the buffer to be "
+                                         "divisible by the number of workers, 
but got numel = "
+                                      << numel << " and " << num_receiver << " 
workers.";
     DataType dtype(buffer->dtype);
-    int64_t numel_per_shard = numel / num_workers;
+    int64_t numel_per_shard = numel / num_receiver;
     int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
     CHECK_EQ(numel_per_shard, recv.Shape()->Product())
         << "ValueError: The number of elements in buffer `recv` must be the 
same as each shard "
@@ -163,40 +176,45 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray 
recv) {
         << numel << ", but `recv.size` is " << recv.Shape()->Product() << ".";
     NCCL_CALL(ncclGroupStart());
     uint8_t* data = static_cast<uint8_t*>(buffer->data);
-    for (int i = 0; i < num_workers; ++i) {
-      NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, 
ctx->comm, stream));
+    for (int i = 0; i < num_receiver; ++i) {
+      NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i,
+                         in_group ? ctx->group_comm : ctx->global_comm, 
stream));
       data += bytes_per_shard;
     }
   } else {
     if (send.defined()) {
-      LOG(WARNING) << "Buffer `send` must be None when worker_id != 0, but got 
"
-                      "send = "
+      LOG(WARNING) << "ValueError: buffer `send` must be None when (worker_id 
!= 0 && !in_group) "
+                      "or (worker_id % group_size != 0 && in_group). However, 
got send = "
                    << send.get() << ". This will be ignored.";
     }
     NCCL_CALL(ncclGroupStart());
   }
   int64_t numel = recv.Shape()->Product();
   DataType dtype(recv->dtype);
-  NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, 
stream));
+  NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0,
+                     in_group ? ctx->group_comm : ctx->global_comm, stream));
   NCCL_CALL(ncclGroupEnd());
 }
 
-void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
+void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv) {
   CHECK(send.defined()) << "ValueError: buffer `send` must not be None";
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   int worker_id = ctx->worker->worker_id;
   int num_workers = ctx->worker->num_workers;
+  int group_size = num_workers / ctx->worker->num_groups;
+  bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % 
group_size == 0);
+  int num_receiver = in_group ? group_size : num_workers;
   deviceStream_t stream = ctx->GetDefaultStream();
-  if (worker_id == 0) {
+  if (is_sender) {
     CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when 
worker_id == 0.";
     NDArray buffer = recv.value();
     int64_t numel = buffer.Shape()->Product();
-    CHECK_EQ(numel % num_workers, 0) << "ValueError: Gathering evenly requires 
that the number "
-                                        "of elements in the buffer to be "
-                                        "divisible by the number of workers, 
but got numel = "
-                                     << numel << " and " << num_workers << " 
workers.";
+    CHECK_EQ(numel % num_receiver, 0) << "ValueError: Gathering evenly 
requires that the number "
+                                         "of elements in the buffer to be "
+                                         "divisible by the number of workers, 
but got numel = "
+                                      << numel << " and " << num_receiver << " 
workers.";
     DataType dtype(buffer->dtype);
-    int64_t numel_per_shard = numel / num_workers;
+    int64_t numel_per_shard = numel / num_receiver;
     int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
     CHECK_EQ(numel_per_shard, send.Shape()->Product())
         << "ValueError: The number of elements in buffer `send` must be the 
same as each shard "
@@ -205,21 +223,23 @@ void GatherToWorker0(NDArray send, Optional<NDArray> 
recv) {
         << numel << ", but `send.size` is " << send.Shape()->Product() << ".";
     NCCL_CALL(ncclGroupStart());
     uint8_t* data = static_cast<uint8_t*>(buffer->data);
-    for (int i = 0; i < num_workers; ++i) {
-      NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, 
ctx->comm, stream));
+    for (int i = 0; i < num_receiver; ++i) {
+      NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i,
+                         in_group ? ctx->group_comm : ctx->global_comm, 
stream));
       data += bytes_per_shard;
     }
   } else {
     if (recv.defined()) {
-      LOG(WARNING) << "ValueError: buffer `recv` must be None when worker_id 
!= 0. However, got "
-                      "recv = "
+      LOG(WARNING) << "ValueError: buffer `recv` must be None when (worker_id 
!= 0 && !in_group) "
+                      "or (worker_id % group_size != 0 && in_group). However, 
got recv = "
                    << recv.get() << ". This will be ignored.";
     }
     NCCL_CALL(ncclGroupStart());
   }
   int64_t numel = send.Shape()->Product();
   DataType dtype(send->dtype);
-  NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, 
stream));
+  NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0,
+                     in_group ? ctx->group_comm : ctx->global_comm, stream));
   NCCL_CALL(ncclGroupEnd());
 }
 
@@ -230,7 +250,7 @@ void RecvFromWorker0(NDArray buffer) {
       << "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
   NCCL_CALL(ncclGroupStart());
   NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), 
AsNCCLDataType(buffer.DataType()), 0,
-                     ctx->comm, stream));
+                     ctx->global_comm, stream));
   NCCL_CALL(ncclGroupEnd());
 }
 
@@ -248,12 +268,14 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME 
".init_ccl").set_body_ty
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker")
     .set_body_typed(InitCCLPerWorker);
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce")
-    .set_body_typed([](NDArray send, int kind, NDArray recv) {
+    .set_body_typed([](NDArray send, int kind, bool in_group, NDArray recv) {
       CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << 
kind;
-      nccl::AllReduce(send, static_cast<ReduceKind>(kind), recv);
+      nccl::AllReduce(send, static_cast<ReduceKind>(kind), in_group, recv);
     });
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather")
-    .set_body_typed([](NDArray send, NDArray recv) { nccl::AllGather(send, 
recv); });
+    .set_body_typed([](NDArray send, bool in_group, NDArray recv) {
+      nccl::AllGather(send, in_group, recv);
+    });
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME 
".broadcast_from_worker0")
     .set_body_typed(BroadcastFromWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME 
".scatter_from_worker0")
diff --git a/src/runtime/disco/nccl/nccl_context.h 
b/src/runtime/disco/nccl/nccl_context.h
index 3fb281f2cb..730479b61a 100644
--- a/src/runtime/disco/nccl/nccl_context.h
+++ b/src/runtime/disco/nccl/nccl_context.h
@@ -121,14 +121,19 @@ struct CCLThreadLocalContext {
   DiscoWorker* worker = nullptr;
   int device_id;
   deviceStream_t default_stream = nullptr;
-  ncclComm_t comm = nullptr;
+  ncclComm_t global_comm = nullptr;
+  ncclComm_t group_comm = nullptr;
 
   ~CCLThreadLocalContext() { Clear(); }
 
   void Clear() {
-    if (comm) {
-      NCCL_CALL(ncclCommDestroy(comm));
-      comm = nullptr;
+    if (group_comm) {
+      NCCL_CALL(ncclCommDestroy(group_comm));
+      group_comm = nullptr;
+    }
+    if (global_comm) {
+      NCCL_CALL(ncclCommDestroy(global_comm));
+      global_comm = nullptr;
     }
     if (default_stream) {
       StreamDestroy(default_stream);
diff --git a/src/runtime/disco/process_session.cc 
b/src/runtime/disco/process_session.cc
index 179010db8a..7c8d0796dd 100644
--- a/src/runtime/disco/process_session.cc
+++ b/src/runtime/disco/process_session.cc
@@ -154,9 +154,10 @@ class DiscoProcessChannel final : public DiscoChannel {
 
 class ProcessSessionObj final : public BcastSessionObj {
  public:
-  explicit ProcessSessionObj(int num_workers, PackedFunc process_pool)
+  explicit ProcessSessionObj(int num_workers, int num_groups, PackedFunc 
process_pool)
       : process_pool_(process_pool),
-        worker_0_(std::make_unique<DiscoWorkerThread>(0, num_workers, 
&worker_zero_data_)) {
+        worker_0_(
+            std::make_unique<DiscoWorkerThread>(0, num_workers, num_groups, 
&worker_zero_data_)) {
     std::vector<int64_t> read_fds;
     std::vector<int64_t> write_fds;
     read_fds.reserve(num_workers - 1);
@@ -258,18 +259,24 @@ class ProcessSessionObj final : public BcastSessionObj {
 TVM_REGISTER_OBJECT_TYPE(DiscoDebugObject);
 TVM_REGISTER_OBJECT_TYPE(ProcessSessionObj);
 
-Session Session::ProcessSession(int num_workers, String process_pool_creator, 
String entrypoint) {
+Session Session::ProcessSession(int num_workers, int num_group, String 
process_pool_creator,
+                                String entrypoint) {
+  CHECK_EQ(num_workers % num_group, 0)
+      << "The number of workers should be divisible by the number of worker 
group.";
   const PackedFunc* pf = Registry::Get(process_pool_creator);
   CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator
             << " in the registry. Please check if it is registered.";
-  PackedFunc process_pool = (*pf)(num_workers, entrypoint);
-  auto n = make_object<ProcessSessionObj>(num_workers, process_pool);
+  PackedFunc process_pool = (*pf)(num_workers, num_group, entrypoint);
+  auto n = make_object<ProcessSessionObj>(num_workers, num_group, 
process_pool);
   return Session(n);
 }
 
-void WorkerProcess(int worker_id, int num_workers, int64_t read_fd, int64_t 
write_fd) {
+void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t 
read_fd,
+                   int64_t write_fd) {
+  CHECK_EQ(num_workers % num_group, 0)
+      << "The number of workers should be divisible by the number of worker 
group.";
   DiscoProcessChannel channel(read_fd, write_fd);
-  DiscoWorker worker(worker_id, num_workers, nullptr, &channel);
+  DiscoWorker worker(worker_id, num_workers, num_group, nullptr, &channel);
   worker.MainLoop();
 }
 
diff --git a/src/runtime/disco/threaded_session.cc 
b/src/runtime/disco/threaded_session.cc
index 22f906b809..cc9a311a6b 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -133,20 +133,20 @@ class DiscoThreadChannel final : public DiscoChannel {
   DiscoThreadedMessageQueue worker_to_controler_;
 };
 
-DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers,
+DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers, int 
num_groups,
                                      WorkerZeroData* worker_zero_data_)
     : channel(std::make_unique<DiscoThreadChannel>()),
-      worker(
-          std::make_unique<DiscoWorker>(worker_id, num_workers, 
worker_zero_data_, channel.get())),
+      worker(std::make_unique<DiscoWorker>(worker_id, num_workers, num_groups, 
worker_zero_data_,
+                                           channel.get())),
       thread(std::make_unique<std::thread>([worker = this->worker.get()] { 
worker->MainLoop(); })) {
 }
 
 class ThreadedSessionObj final : public BcastSessionObj {
  public:
-  explicit ThreadedSessionObj(int num_workers) {
+  explicit ThreadedSessionObj(int num_workers, int num_groups) {
     for (int i = 0; i < num_workers; ++i) {
       WorkerZeroData* data = (i == 0) ? &worker_zero_data_ : nullptr;
-      workers_.emplace_back(i, num_workers, data);
+      workers_.emplace_back(i, num_workers, num_groups, data);
     }
   }
 
@@ -185,8 +185,10 @@ class ThreadedSessionObj final : public BcastSessionObj {
 
 TVM_REGISTER_OBJECT_TYPE(ThreadedSessionObj);
 
-Session Session::ThreadedSession(int num_workers) {
-  ObjectPtr<ThreadedSessionObj> n = 
make_object<ThreadedSessionObj>(num_workers);
+Session Session::ThreadedSession(int num_workers, int num_group) {
+  CHECK_EQ(num_workers % num_group, 0)
+      << "The number of workers should be divisible by the number of worker 
group.";
+  ObjectPtr<ThreadedSessionObj> n = 
make_object<ThreadedSessionObj>(num_workers, num_group);
   return Session(std::move(n));
 }
 
diff --git a/tests/python/disco/test_callback.py 
b/tests/python/disco/test_callback.py
index 6e2dc9b747..3f8d5e9e52 100644
--- a/tests/python/disco/test_callback.py
+++ b/tests/python/disco/test_callback.py
@@ -30,16 +30,17 @@ from tvm.script import relax as R, tir as T
 
 @tvm.testing.requires_nccl
 def test_callback():
+    """Simulate lazy loading of parameters in a callback
+
+    The output of a lazy parameter loading, which would accept a
+    callback to load the parameters.
+    """
+
     @R.function
     def transform_params(
         rank_arg: R.Prim(value="rank"),
         fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object),
     ):
-        """Simulate lazy loading of parameters in a callback
-
-        The output of a lazy parameter loading, which would accept a
-        callback to load the parameters.
-        """
         rank = T.int64()
 
         A = fget_item(R.str("A"), R.prim_value(0))
diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py
index 5831f245df..6c63f64554 100644
--- a/tests/python/disco/test_ccl.py
+++ b/tests/python/disco/test_ccl.py
@@ -78,6 +78,42 @@ def test_allreduce(session_kind, ccl):
         np.testing.assert_equal(result, expected)
 
 
[email protected]("session_kind", _all_session_kinds)
[email protected]("ccl", _ccl)
+def test_group_allreduce(session_kind, ccl):
+    devices = [0, 1, 2, 3]
+    sess = session_kind(num_workers=len(devices), num_groups=2)
+    sess.init_ccl(ccl, *devices)
+
+    array_1 = np.arange(12, dtype="float32").reshape(3, 4)
+    array_2 = np.arange(start=1, stop=-11, step=-1, 
dtype="float32").reshape(3, 4)
+    array_3 = np.arange(30, dtype="float32").reshape(5, 6)
+    array_4 = np.arange(start=1, stop=-29, step=-1, 
dtype="float32").reshape(5, 6)
+    d_array_1 = sess.empty((3, 4), "float32")
+    d_array_2 = sess.empty((5, 6), "float32")
+    d_array_1.debug_copy_from(0, array_1)
+    d_array_1.debug_copy_from(1, array_2)
+    d_array_2.debug_copy_from(2, array_3)
+    d_array_2.debug_copy_from(3, array_4)
+    for op, np_op in [  # pylint: disable=invalid-name
+        ("sum", np.add),
+        ("prod", np.multiply),
+        ("min", np.minimum),
+        ("max", np.maximum),
+        ("avg", lambda a, b: (a + b) * 0.5),
+    ]:
+        dst_array_1 = sess.empty((3, 4), "float32")
+        dst_array_2 = sess.empty((5, 6), "float32")
+        sess.allreduce(d_array_1, dst_array_1, op=op, in_group=True)
+        sess.allreduce(d_array_2, dst_array_2, op=op, in_group=True)
+        result_1 = dst_array_1.debug_get_from_remote(0).numpy()
+        result_2 = dst_array_2.debug_get_from_remote(2).numpy()
+        expected_1 = np_op(array_1, array_2)
+        expected_2 = np_op(array_3, array_4)
+        np.testing.assert_equal(result_1, expected_1)
+        np.testing.assert_equal(result_2, expected_2)
+
+
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 @pytest.mark.parametrize("ccl", _ccl)
 def test_allgather(session_kind, ccl):
@@ -101,10 +137,47 @@ def test_allgather(session_kind, ccl):
     )
 
 
[email protected]("session_kind", _all_session_kinds)
[email protected]("ccl", _ccl)
+def test_group_allgather(session_kind, ccl):
+    devices = [0, 1, 2, 3]
+    sess = session_kind(num_workers=len(devices), num_groups=2)
+    sess.init_ccl(ccl, *devices)
+
+    array_1 = np.arange(36, dtype="float32")
+    array_2 = np.arange(48, dtype="float32")
+    d_src_1 = sess.empty((3, 3, 2), "float32")
+    d_dst_1 = sess.empty((3, 4, 3), "float32")
+    d_src_2 = sess.empty((2, 4, 3), "float32")
+    d_dst_2 = sess.empty((2, 6, 4), "float32")
+    d_src_1.debug_copy_from(0, array_1[:18])
+    d_src_1.debug_copy_from(1, array_1[18:])
+    d_src_2.debug_copy_from(2, array_2[:24])
+    d_src_2.debug_copy_from(3, array_2[24:])
+    sess.allgather(d_src_1, d_dst_1, in_group=True)
+    sess.allgather(d_src_2, d_dst_2, in_group=True)
+    np.testing.assert_equal(
+        d_dst_1.debug_get_from_remote(0).numpy(),
+        array_1.reshape(3, 4, 3),
+    )
+    np.testing.assert_equal(
+        d_dst_1.debug_get_from_remote(1).numpy(),
+        array_1.reshape(3, 4, 3),
+    )
+    np.testing.assert_equal(
+        d_dst_2.debug_get_from_remote(2).numpy(),
+        array_2.reshape(2, 6, 4),
+    )
+    np.testing.assert_equal(
+        d_dst_2.debug_get_from_remote(3).numpy(),
+        array_2.reshape(2, 6, 4),
+    )
+
+
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 @pytest.mark.parametrize("ccl", _ccl)
 @pytest.mark.parametrize("use_explicit_output", [True, False])
-def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output):
+def test_broadcast(session_kind, ccl, use_explicit_output):
     devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
     sess.init_ccl(ccl, *devices)
@@ -123,6 +196,29 @@ def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output):
     np.testing.assert_equal(result, array)
 
 
+@pytest.mark.parametrize("session_kind", _all_session_kinds)
+@pytest.mark.parametrize("ccl", _ccl)
+def test_group_broadcast(session_kind, ccl):
+    devices = [0, 1, 2, 3]
+    sess = session_kind(num_workers=len(devices), num_groups=2)
+    sess.init_ccl(ccl, *devices)
+
+    array_1 = np.arange(12, dtype="float32").reshape(3, 4)
+    array_2 = np.multiply(array_1, -1)
+
+    src_array = sess.empty((3, 4), "float32", worker0_only=True, in_group=True)
+    src_array.debug_copy_from(0, array_1)
+    src_array.debug_copy_from(2, array_2)
+    dst_array = sess.empty((3, 4), "float32")
+    sess.broadcast_from_worker0(src_array, dst_array)
+
+    result_1 = dst_array.debug_get_from_remote(1).numpy()
+    np.testing.assert_equal(result_1, array_1)
+
+    result_3 = dst_array.debug_get_from_remote(3).numpy()
+    np.testing.assert_equal(result_3, array_2)
+
+
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 @pytest.mark.parametrize("ccl", _ccl)
 @pytest.mark.parametrize("use_explicit_output", [True, False])
@@ -156,6 +252,45 @@ def test_scatter(session_kind, ccl, use_explicit_output, capfd):
     ), "No warning messages should be generated from disco.Session.scatter_from_worker0"
 
 
+@pytest.mark.parametrize("session_kind", _all_session_kinds)
+@pytest.mark.parametrize("ccl", _ccl)
+def test_group_scatter(session_kind, ccl, capfd):
+    devices = [0, 1, 2, 3]
+    sess = session_kind(num_workers=len(devices), num_groups=2)
+    sess.init_ccl(ccl, *devices)
+
+    array_1 = np.arange(36, dtype="float32").reshape(2, 6, 3)
+    array_2 = np.multiply(array_1, -1)
+
+    d_src = sess.empty((2, 6, 3), "float32", worker0_only=True, in_group=True)
+    d_src.debug_copy_from(0, array_1)
+    d_src.debug_copy_from(2, array_2)
+    d_dst = sess.empty((6, 3), "float32")
+    sess.scatter_from_worker0(d_src, d_dst)
+
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(0).numpy(),
+        array_1[0, :, :],
+    )
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(1).numpy(),
+        array_1[1, :, :],
+    )
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(2).numpy(),
+        array_2[0, :, :],
+    )
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(3).numpy(),
+        array_2[1, :, :],
+    )
+
+    captured = capfd.readouterr()
+    assert (
+        not captured.err
+    ), "No warning messages should be generated from disco.Session.scatter_from_worker0"
+
+
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 @pytest.mark.parametrize("ccl", _ccl)
 def test_scatter_with_implicit_reshape(session_kind, ccl, capfd):
@@ -225,6 +360,37 @@ def test_gather(session_kind, ccl, capfd):
     ), "No warning messages should be generated from disco.Session.gather_to_worker0"
 
 
+@pytest.mark.parametrize("session_kind", _all_session_kinds)
+@pytest.mark.parametrize("ccl", _ccl)
+def test_group_gather(session_kind, ccl, capfd):
+    devices = [0, 1, 2, 3]
+    sess = session_kind(num_workers=len(devices), num_groups=2)
+    sess.init_ccl(ccl, *devices)
+
+    array_1 = np.arange(36, dtype="float32")
+    array_2 = np.multiply(array_1, -1)
+    d_src = sess.empty((3, 3, 2), "float32")
+    d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True, in_group=True)
+    d_src.debug_copy_from(0, array_1[:18])
+    d_src.debug_copy_from(1, array_1[18:])
+    d_src.debug_copy_from(2, array_2[:18])
+    d_src.debug_copy_from(3, array_2[18:])
+    sess.gather_to_worker0(d_src, d_dst)
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(0).numpy(),
+        array_1.reshape(3, 4, 3),
+    )
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(2).numpy(),
+        array_2.reshape(3, 4, 3),
+    )
+
+    captured = capfd.readouterr()
+    assert (
+        not captured.err
+    ), "No warning messages should be generated from disco.Session.gather_to_worker0"
+
+
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 @pytest.mark.parametrize("ccl", _ccl)
 def test_mlp(session_kind, ccl):  # pylint: disable=too-many-locals
diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py
index 502cbe0b81..b4e2440857 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -22,6 +22,7 @@ import tempfile
 import numpy as np
 
 import tvm
+import tvm.testing
 from tvm import dlight as dl
 from tvm import relax as rx
 from tvm._ffi import register_func
@@ -246,7 +247,7 @@ def test_load_shard_in_relax():
         @R.function
         def main(
             loader: R.Object,
-        ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128), "float32"),):
+        ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128), "float32")):
             R.func_attr({"global_symbol": "main"})
             with R.dataflow():
                 lv0: R.Tensor((64, 64), "float32") = R.call_pure_packed(
diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py
index ef8ea2e70a..837b3a14f2 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -22,13 +22,14 @@ import numpy as np
 import pytest
 
 import tvm
+import tvm.testing
 from tvm import relax as rx
 from tvm.runtime import ShapeTuple, String
 from tvm.runtime import disco as di
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
-from tvm.testing import disco as _
+from tvm.exec import disco_worker as _
 
 
 def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
@@ -168,14 +169,14 @@ def test_vm_multi_func(session_kind):
         @T.prim_func
         def t1(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
             for i, j in T.grid(16, 8):
-                with T.block("transpose"):
+                with T.block("t1"):
                     vi, vj = T.axis.remap("SS", [i, j])
                     B[vi, vj] = A[vj, vi]
 
         @T.prim_func
         def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")):
             for i, j in T.grid(8, 16):
-                with T.block("transpose"):
+                with T.block("t2"):
                     vi, vj = T.axis.remap("SS", [i, j])
                     B[vi, vj] = A[vj, vi]
 
@@ -183,7 +184,7 @@ def test_vm_multi_func(session_kind):
         def transpose_1(
             A: R.Tensor((8, 16), dtype="float32")
         ) -> R.Tensor((16, 8), dtype="float32"):
-            R.func_attr({"global_symbol": "main"})
+            R.func_attr({"global_symbol": "transpose_1"})
             cls = TestMod
             with R.dataflow():
                 B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32"))
@@ -194,7 +195,7 @@ def test_vm_multi_func(session_kind):
         def transpose_2(
             A: R.Tensor((16, 8), dtype="float32")
         ) -> R.Tensor((8, 16), dtype="float32"):
-            R.func_attr({"global_symbol": "main"})
+            R.func_attr({"global_symbol": "transpose_2"})
             cls = TestMod
             with R.dataflow():
                 B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32"))
@@ -228,11 +229,4 @@ def test_num_workers(session_kind, num_workers):
 
 
 if __name__ == "__main__":
-    test_int(di.ProcessSession)
-    test_float(di.ProcessSession)
-    test_string(di.ProcessSession)
-    test_string_obj(di.ProcessSession)
-    test_shape_tuple(di.ProcessSession)
-    test_ndarray(di.ProcessSession)
-    test_vm_module(di.ProcessSession)
-    test_vm_multi_func(di.ProcessSession)
+    tvm.testing.main()
diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py
index 3a76f535d7..6ee64a1815 100644
--- a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py
+++ b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py
@@ -220,7 +220,7 @@ def test_mlp():
                 out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"),
             )
             lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce(
-                gv, op_type="sum"
+                gv, op_type="sum", in_group=False
             )
             return lv3
 
@@ -1559,7 +1559,7 @@ def test_llama_attention():
                 out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"),
             )
             lv43: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.ccl.allreduce(
-                gv, op_type="sum"
+                gv, op_type="sum", in_group=False
             )
             lv44: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.dist.call_tir_local_view(
                 cls.add,
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py
index 63563ee3c9..9ea4d21d61 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -40,11 +40,11 @@ def test_allreduce():
     class Expected:
         @R.function
         def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
-            gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
-            gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
-            gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
-            gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
-            gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4])], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+            gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+            gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+            gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+            gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
+            gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4]), True], out_sinfo=R.Tensor((10, 10), dtype="float32"))
             return x
     # fmt: on
 
@@ -66,8 +66,8 @@ def test_allgather():
     class Expected:
         @R.function
         def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
-            gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32"))
-            gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32"))
+            gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32"))
+            gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32"))
             return x
     # fmt: on
 
@@ -88,7 +88,7 @@ def test_broadcast_from_zero():
     class Expected:
         @R.function
         def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
-            gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", x, out_sinfo=R.Tensor((10, 10), dtype="float32"))
+            gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", [x, False], out_sinfo=R.Tensor((10, 10), dtype="float32"))
             return x
     # fmt: on
 
@@ -134,7 +134,7 @@ def test_scatter_from_worker0():
             cls = Expected
             gv = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((10, 2, 5), dtype="float32"))
             gv1 = R.call_tir(cls.transpose, (gv,), out_sinfo=R.Tensor((2, 10, 5), dtype="float32"))
-            gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1,), out_sinfo=R.Tensor((10, 5), dtype="float32"))
+            gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1, False), out_sinfo=R.Tensor((10, 5), dtype="float32"))
             return gv0
     # fmt: on
 

Reply via email to