This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 71b81127a2 [Disco][Op] broadcast_from_worker0 (#15633)
71b81127a2 is described below

commit 71b81127a29255ce9280d5397f64de8d82e4f5fb
Author: Lesheng Jin <[email protected]>
AuthorDate: Mon Aug 28 14:26:51 2023 -0700

    [Disco][Op] broadcast_from_worker0 (#15633)
    
    This PR introduces the op broadcast_from_worker0, which broadcasts the input
    tensor from worker-0 to all other workers.
---
 python/tvm/relax/op/ccl/ccl.py                     | 18 ++++++
 python/tvm/relax/transform/legalize_ops/ccl.py     |  9 +++
 python/tvm/runtime/disco/session.py                |  2 +-
 src/relax/op/ccl/ccl.cc                            | 20 ++++++
 src/relax/op/ccl/ccl.h                             |  3 +
 src/runtime/disco/nccl/nccl.cc                     |  3 +-
 tests/python/disco/test_nccl.py                    | 16 ++---
 tests/python/relax/test_op_ccl.py                  | 75 ++++++++++++++++++++++
 .../relax/test_transform_legalize_ops_ccl.py       | 21 ++++++
 9 files changed, 157 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 7fede70543..093e01b638 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -17,6 +17,8 @@
 """Relax Collective Communications Library (CCL) operators"""
 from . import _ffi_api
 
+from ...expr import Expr
+
 
 def allreduce(x, op_type: str = "sum"):  # pylint: disable=invalid-name
     """Allreduce operator
@@ -40,3 +42,19 @@ def allreduce(x, op_type: str = "sum"):  # pylint: 
disable=invalid-name
         f"including {supported_op_types}, but got {op_type}."
     )
     return _ffi_api.allreduce(x, op_type)  # type: ignore # pylint: 
disable=no-member
+
+
+def broadcast_from_worker0(x: Expr) -> Expr:
+    """Broadcast data from worker-0 to all other workers.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The tensor to be broadcast. Only the value held by worker-0 is used.
+
+    Returns
+    -------
+    result : relax.Expr
+        A tensor with the same struct info as `x`, holding worker-0's data on every worker.
+    """
+    return _ffi_api.broadcast_from_worker0(x)  # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py 
b/python/tvm/relax/transform/legalize_ops/ccl.py
index b1df104518..c9c09952b2 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -43,3 +43,12 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
         ShapeExpr([op_type_map[op_type_str]]),
         sinfo_args=call.args[0].struct_info,
     )
+
+
+@register_legalize("relax.ccl.broadcast_from_worker0")
+def broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+    return call_pure_packed(
+        "runtime.disco.broadcast_from_worker0",
+        call.args[0],
+        sinfo_args=call.args[0].struct_info,
+    )
diff --git a/python/tvm/runtime/disco/session.py 
b/python/tvm/runtime/disco/session.py
index e271f72138..587704fe28 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -267,7 +267,7 @@ class Session(Object):
         func = self.get_global_func(f"runtime.disco.{api}.init")
         func(*args)
 
-    def broadcast_from_worker0(self, array: DRef) -> None:
+    def broadcast_from_worker0(self, array: DRef) -> DRef:
         """Broadcast an array from worker-0 to all other workers.
 
         Parameters
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 40c897532d..6fa6c96db3 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -50,5 +50,25 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.ccl.broadcast_from_worker0 */
+Expr broadcast_from_worker0(Expr x) {
+  static const Op& op = Op::Get("relax.ccl.broadcast_from_worker0");
+  return Call(op, {std::move(x)}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.broadcast_from_worker0").set_body_typed(broadcast_from_worker0);
+
+StructInfo InferStructInfoBroadcastFromZero(const Call& call, const 
BlockBuilder& ctx) {
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  return input_sinfo;
+}
+
+TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "Input to be broadcast.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoBroadcastFromZero)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index f87512c138..55402f3d37 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -35,6 +35,9 @@ namespace relax {
 /*! \brief AllReduce. */
 Expr allreduce(Expr data, String op_type);
 
+/*! \brief Broadcast data from worker-0 to all other workers. */
+Expr broadcast_from_worker0(Expr data);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index f1b6c6c623..48f552c5aa 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -109,13 +109,14 @@ NDArray AllReduce(NDArray send, ReduceKind reduce_kind) {
   return recv;
 }
 
-void BroadcastFromZero(NDArray buffer) {
+NDArray BroadcastFromZero(NDArray buffer) {
   ShapeTuple shape = buffer.Shape();
   int64_t numel = GetNumel(shape);
   NCCL_CALL(ncclBroadcast(buffer->data, buffer->data, numel,
                           
/*datatype=*/AsNCCLDataType(DataType(buffer->dtype)),  //
                           /*root=*/0, 
NCCLGlobalContext::ThreadLocalCommunicator(),
                           NCCLGlobalContext::ThreadLocalStream()));
+  return buffer;
 }
 
 TVM_REGISTER_GLOBAL("runtime.disco.nccl.init").set_body([](TVMArgs args, 
TVMRetValue* rv) -> void {
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index 71f13859ab..af7da28db2 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -60,7 +60,7 @@ def test_allreduce():
         np.testing.assert_equal(result, expected)
 
 
-def test_broadcast_from_zero():
+def test_broadcast_from_worker0():
     num_workers = 2
     devices = [1, 2]
     array = np.arange(12, dtype="float32").reshape(3, 4)
@@ -105,7 +105,8 @@ def test_mlp():  # pylint: disable=too-many-locals
         ) -> R.Tensor((128, 128), "float32"):
             R.func_attr({"global_symbol": "main"})
             with R.dataflow():
-                lv0: R.Tensor((128, 64), "float32") = R.matmul(x, W1)
+                broadcast_x: R.Tensor((128, 128), "float32") = 
R.ccl.broadcast_from_worker0(x)
+                lv0: R.Tensor((128, 64), "float32") = R.matmul(broadcast_x, W1)
                 lv1: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
                 lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
                 lv3: R.Tensor((128, 128), "float32") = R.ccl.allreduce(lv2, 
"sum")
@@ -159,7 +160,6 @@ def test_mlp():  # pylint: disable=too-many-locals
         d_W2 = sess.empty((64, 128), "float32")
 
         d_X.debug_copy_from(0, X)
-        d_X.debug_copy_from(1, X)
         d_W1.debug_copy_from(0, W1[:, :64])
         d_W1.debug_copy_from(1, W1[:, 64:])
         d_W2.debug_copy_from(0, W2[:64, :])
@@ -230,16 +230,17 @@ def test_attention():  # pylint: disable=too-many-locals
         ) -> R.Tensor((128, 128), "float32"):
             R.func_attr({"global_symbol": "main"})
             with R.dataflow():
+                broadcast_x: R.Tensor((1, 10, 128), "float32") = 
R.ccl.broadcast_from_worker0(x)
                 # q
-                lv0: R.Tensor((1, 10, 256), "float32") = R.matmul(x, Wq)
+                lv0: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, 
Wq)
                 lv1: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv0, [1, 
10, 4, 64])
                 lv2: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv1, 
[0, 2, 1, 3])
                 # k
-                lv3: R.Tensor((1, 10, 256), "float32") = R.matmul(x, Wk)
+                lv3: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, 
Wk)
                 lv4: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv3, [1, 
10, 4, 64])
                 lv5: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv4, 
[0, 2, 1, 3])
                 # v
-                lv6: R.Tensor((1, 10, 256), "float32") = R.matmul(x, Wv)
+                lv6: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, 
Wv)
                 lv7: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv6, [1, 
10, 4, 64])
                 lv8: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv7, 
[0, 2, 1, 3])
                 # softmax(q @ k / sqrt(dk))
@@ -312,7 +313,6 @@ def test_attention():  # pylint: disable=too-many-locals
         d_Wo = sess.empty((256, 128), "float32")
 
         d_X.debug_copy_from(0, X)
-        d_X.debug_copy_from(1, X)
         d_Wq.debug_copy_from(0, Wq[:, :256])
         d_Wq.debug_copy_from(1, Wq[:, 256:])
         d_Wk.debug_copy_from(0, Wk[:, :256])
@@ -332,7 +332,7 @@ def test_attention():  # pylint: disable=too-many-locals
 
 if __name__ == "__main__":
     test_init()
-    test_broadcast_from_zero()
+    test_broadcast_from_worker0()
     test_allreduce()
     test_mlp()
     test_attention()
diff --git a/tests/python/relax/test_op_ccl.py 
b/tests/python/relax/test_op_ccl.py
index fd25b393cb..09924d27ec 100644
--- a/tests/python/relax/test_op_ccl.py
+++ b/tests/python/relax/test_op_ccl.py
@@ -26,6 +26,7 @@ from tvm.script import relax as R
 def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
     assert relax.op.ccl.allreduce(x).op == Op.get("relax.ccl.allreduce")
+    assert relax.op.ccl.broadcast_from_worker0(x).op == 
Op.get("relax.ccl.broadcast_from_worker0")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
@@ -85,5 +86,79 @@ def test_allreduce_infer_struct_info_more_input_dtype():
     _check_inference(bb, relax.op.ccl.allreduce(x2), 
relax.TensorStructInfo((2, 3), "int64"))
 
 
+def test_broadcast_from_worker0_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+    x3 = relax.Var("x", R.Tensor((2, 3)))
+    x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((3, 4)))
+
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x0), 
relax.TensorStructInfo((2, 3), "float32")
+    )
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x1), 
relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x2), 
relax.TensorStructInfo(dtype="float32")
+    )
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x3), 
relax.TensorStructInfo((2, 3), dtype="")
+    )
+    _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x4), 
relax.TensorStructInfo(dtype=""))
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x5), 
relax.TensorStructInfo((3, 4), dtype="")
+    )
+
+
+def test_broadcast_from_worker0_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    x1 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x0), 
relax.TensorStructInfo((m, n), "float32")
+    )
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x1), 
relax.TensorStructInfo((4, n), "float32")
+    )
+
+
+def test_broadcast_from_worker0_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x0), 
relax.TensorStructInfo(s0, "float32")
+    )
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x1), 
relax.TensorStructInfo(s1, "float32")
+    )
+
+
+def test_broadcast_from_worker0_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x0), 
relax.TensorStructInfo((2, 3), "float64")
+    )
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x1), 
relax.TensorStructInfo((2, 3), "int8")
+    )
+    _check_inference(
+        bb, relax.op.ccl.broadcast_from_worker0(x2), 
relax.TensorStructInfo((2, 3), "int64")
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py 
b/tests/python/relax/test_transform_legalize_ops_ccl.py
index ef535bef53..9bce76cecb 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -51,5 +51,26 @@ def test_allreduce():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_broadcast_from_worker0():
+    # fmt: off
+    @tvm.script.ir_module
+    class BroadcastFromWorker0:
+        @R.function
+        def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"):
+            gv0: R.Tensor((10, 10), "float32") = R.ccl.broadcast_from_worker0(x)
+            return gv0
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+            gv0: R.Tensor((10, 10), dtype="float32") = R.call_pure_packed("runtime.disco.broadcast_from_worker0", x, sinfo_args=R.Tensor((10, 10), dtype="float32"))
+            return gv0
+    # fmt: on
+
+    mod = LegalizeOps()(BroadcastFromWorker0)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to