This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 71b81127a2 [Disco][Op] broadcast_from_worker0 (#15633)
71b81127a2 is described below
commit 71b81127a29255ce9280d5397f64de8d82e4f5fb
Author: Lesheng Jin <[email protected]>
AuthorDate: Mon Aug 28 14:26:51 2023 -0700
[Disco][Op] broadcast_from_worker0 (#15633)
This PR introduces the op broadcast_from_worker0, which broadcasts the input
tensor from worker-0 to all other workers.
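
For orientation, a minimal usage sketch of the new operator in TVMScript
(not part of this commit; it mirrors the legalization test added below):

    # Sketch only: using the new op in a Relax function.
    import tvm
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"):
            # Replicates x from worker-0 to all other workers; shape and dtype unchanged.
            gv0: R.Tensor((10, 10), "float32") = R.ccl.broadcast_from_worker0(x)
            return gv0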
---
python/tvm/relax/op/ccl/ccl.py | 18 ++++++
python/tvm/relax/transform/legalize_ops/ccl.py | 9 +++
python/tvm/runtime/disco/session.py | 2 +-
src/relax/op/ccl/ccl.cc | 20 ++++++
src/relax/op/ccl/ccl.h | 3 +
src/runtime/disco/nccl/nccl.cc | 3 +-
tests/python/disco/test_nccl.py | 16 ++---
tests/python/relax/test_op_ccl.py | 75 ++++++++++++++++++++++
.../relax/test_transform_legalize_ops_ccl.py | 21 ++++++
9 files changed, 157 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 7fede70543..093e01b638 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -17,6 +17,8 @@
"""Relax Collective Communications Library (CCL) operators"""
from . import _ffi_api
+from ...expr import Expr
+
def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name
"""Allreduce operator
@@ -40,3 +42,19 @@ def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name
f"including {supported_op_types}, but got {op_type}."
)
return _ffi_api.allreduce(x, op_type) # type: ignore # pylint: disable=no-member
+
+
+def broadcast_from_worker0(x: Expr) -> Expr:
+ """Broadcast data from worker-0 to all other workers.
+
+ Parameters
+ ----------
+ x : relax.Expr
+ The tensor to be broadcast.
+
+ Returns
+ -------
+ result : relax.Expr
+ The same tensor, which has been broadcast to all other workers.
+ """
+ return _ffi_api.broadcast_from_worker0(x)
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py
index b1df104518..c9c09952b2 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -43,3 +43,12 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
ShapeExpr([op_type_map[op_type_str]]),
sinfo_args=call.args[0].struct_info,
)
+
+
+@register_legalize("relax.ccl.broadcast_from_worker0")
+def broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+ return call_pure_packed(
+ "runtime.disco.broadcast_from_worker0",
+ call.args[0],
+ sinfo_args=call.args[0].struct_info,
+ )
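
A hedged sketch of what this registration does, reusing the Module sketch
from the top of this mail (that module is an illustration, not commit code):

    from tvm.relax.transform import LegalizeOps

    # Rewrites R.ccl.broadcast_from_worker0 into a packed call, as verified
    # by the test at the end of this diff.
    mod = LegalizeOps()(Module)
    mod.show()  # main now contains:
                #   R.call_pure_packed("runtime.disco.broadcast_from_worker0", x,
                #                      sinfo_args=R.Tensor((10, 10), dtype="float32"))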
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index e271f72138..587704fe28 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -267,7 +267,7 @@ class Session(Object):
func = self.get_global_func(f"runtime.disco.{api}.init")
func(*args)
- def broadcast_from_worker0(self, array: DRef) -> None:
+ def broadcast_from_worker0(self, array: DRef) -> DRef:
"""Broadcast an array from worker-0 to all other workers.
Parameters
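
A hedged usage sketch of the updated method; the session construction and
init_ccl call below are assumptions modeled on tests/python/disco/test_nccl.py
(the exact constructor may differ), while empty, debug_copy_from, and
broadcast_from_worker0 appear in this commit's diffs:

    import numpy as np
    from tvm.runtime import disco as di

    sess = di.ProcessSession(num_workers=2)  # assumption: constructor name may differ
    sess.init_ccl("nccl", 0, 1)              # assumption: NCCL init on two devices
    d_x = sess.empty((3, 4), "float32")      # DRef allocated on every worker
    d_x.debug_copy_from(0, np.arange(12, dtype="float32").reshape(3, 4))  # worker-0 only
    d_x = sess.broadcast_from_worker0(d_x)   # now returns the DRef, per this change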
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 40c897532d..6fa6c96db3 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -50,5 +50,25 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.ccl.broadcast_from_worker0 */
+Expr broadcast_from_worker0(Expr x) {
+ static const Op& op = Op::Get("relax.ccl.broadcast_from_worker0");
+ return Call(op, {std::move(x)}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.broadcast_from_worker0").set_body_typed(broadcast_from_worker0);
+
+StructInfo InferStructInfoBroadcastFromZero(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ return input_sinfo;
+}
+
+TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
+ .set_num_inputs(1)
+ .add_argument("x", "Tensor", "Input to be broadcast.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBroadcastFromZero)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
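
The inference rule above simply forwards the input's TensorStructInfo. A
hedged Python sketch of observing this through the BlockBuilder, in the
style of the tests added below:

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    with bb.function("main", [x]):
        gv = bb.emit(relax.op.ccl.broadcast_from_worker0(x))  # sinfo inferred on emit
        bb.emit_func_output(gv)
    # Same shape and dtype as the input.
    tvm.ir.assert_structural_equal(gv.struct_info, relax.TensorStructInfo((2, 3), "float32"))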
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index f87512c138..55402f3d37 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -35,6 +35,9 @@ namespace relax {
/*! \brief AllReduce. */
Expr allreduce(Expr data, String op_type);
+/*! \brief Broadcast data from worker-0 to all other workers. */
+Expr broadcast_from_worker0(Expr data);
+
} // namespace relax
} // namespace tvm
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index f1b6c6c623..48f552c5aa 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -109,13 +109,14 @@ NDArray AllReduce(NDArray send, ReduceKind reduce_kind) {
return recv;
}
-void BroadcastFromZero(NDArray buffer) {
+NDArray BroadcastFromZero(NDArray buffer) {
ShapeTuple shape = buffer.Shape();
int64_t numel = GetNumel(shape);
NCCL_CALL(ncclBroadcast(buffer->data, buffer->data, numel,
/*datatype=*/AsNCCLDataType(DataType(buffer->dtype)), //
/*root=*/0,
NCCLGlobalContext::ThreadLocalCommunicator(),
NCCLGlobalContext::ThreadLocalStream()));
+ return buffer;
}
TVM_REGISTER_GLOBAL("runtime.disco.nccl.init").set_body([](TVMArgs args,
TVMRetValue* rv) -> void {
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index 71f13859ab..af7da28db2 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -60,7 +60,7 @@ def test_allreduce():
np.testing.assert_equal(result, expected)
-def test_broadcast_from_zero():
+def test_broadcast_from_worker0():
num_workers = 2
devices = [1, 2]
array = np.arange(12, dtype="float32").reshape(3, 4)
@@ -105,7 +105,8 @@ def test_mlp(): # pylint: disable=too-many-locals
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
- lv0: R.Tensor((128, 64), "float32") = R.matmul(x, W1)
+ broadcast_x: R.Tensor((128, 128), "float32") = R.ccl.broadcast_from_worker0(x)
+ lv0: R.Tensor((128, 64), "float32") = R.matmul(broadcast_x, W1)
lv1: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
lv3: R.Tensor((128, 128), "float32") = R.ccl.allreduce(lv2, "sum")
@@ -159,7 +160,6 @@ def test_mlp(): # pylint: disable=too-many-locals
d_W2 = sess.empty((64, 128), "float32")
d_X.debug_copy_from(0, X)
- d_X.debug_copy_from(1, X)
d_W1.debug_copy_from(0, W1[:, :64])
d_W1.debug_copy_from(1, W1[:, 64:])
d_W2.debug_copy_from(0, W2[:64, :])
@@ -230,16 +230,17 @@ def test_attention(): # pylint: disable=too-many-locals
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
+ broadcast_x: R.Tensor((1, 10, 128), "float32") = R.ccl.broadcast_from_worker0(x)
# q
- lv0: R.Tensor((1, 10, 256), "float32") = R.matmul(x, Wq)
+ lv0: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wq)
lv1: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv0, [1, 10, 4, 64])
lv2: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv1, [0, 2, 1, 3])
# k
- lv3: R.Tensor((1, 10, 256), "float32") = R.matmul(x, Wk)
+ lv3: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wk)
lv4: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv3, [1, 10, 4, 64])
lv5: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv4, [0, 2, 1, 3])
# v
- lv6: R.Tensor((1, 10, 256), "float32") = R.matmul(x, Wv)
+ lv6: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wv)
lv7: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv6, [1, 10, 4, 64])
lv8: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv7, [0, 2, 1, 3])
# softmax(q @ k / sqrt(dk))
@@ -312,7 +313,6 @@ def test_attention(): # pylint: disable=too-many-locals
d_Wo = sess.empty((256, 128), "float32")
d_X.debug_copy_from(0, X)
- d_X.debug_copy_from(1, X)
d_Wq.debug_copy_from(0, Wq[:, :256])
d_Wq.debug_copy_from(1, Wq[:, 256:])
d_Wk.debug_copy_from(0, Wk[:, :256])
@@ -332,7 +332,7 @@ def test_attention(): # pylint: disable=too-many-locals
if __name__ == "__main__":
test_init()
- test_broadcast_from_zero()
+ test_broadcast_from_worker0()
test_allreduce()
test_mlp()
test_attention()
diff --git a/tests/python/relax/test_op_ccl.py b/tests/python/relax/test_op_ccl.py
index fd25b393cb..09924d27ec 100644
--- a/tests/python/relax/test_op_ccl.py
+++ b/tests/python/relax/test_op_ccl.py
@@ -26,6 +26,7 @@ from tvm.script import relax as R
def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3), "float32"))
assert relax.op.ccl.allreduce(x).op == Op.get("relax.ccl.allreduce")
+ assert relax.op.ccl.broadcast_from_worker0(x).op == Op.get("relax.ccl.broadcast_from_worker0")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -85,5 +86,79 @@ def test_allreduce_infer_struct_info_more_input_dtype():
_check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorStructInfo((2, 3), "int64"))
+def test_broadcast_from_worker0_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+ x3 = relax.Var("x", R.Tensor((2, 3)))
+ x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((3, 4)))
+
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorStructInfo((2, 3), "float32")
+ )
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x2), relax.TensorStructInfo(dtype="float32")
+ )
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x3), relax.TensorStructInfo((2, 3), dtype="")
+ )
+ _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x4), relax.TensorStructInfo(dtype=""))
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x5), relax.TensorStructInfo((3, 4), dtype="")
+ )
+
+
+def test_broadcast_from_worker0_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ m = tir.Var("m", "int64")
+ n = tir.Var("n", "int64")
+ x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+ x1 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorStructInfo((m, n), "float32")
+ )
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorStructInfo((4, n), "float32")
+ )
+
+
+def test_broadcast_from_worker0_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+ s1 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorStructInfo(s0, "float32")
+ )
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorStructInfo(s1, "float32")
+ )
+
+
+def test_broadcast_from_worker0_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
+ x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+ x2 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorStructInfo((2, 3), "float64")
+ )
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorStructInfo((2, 3), "int8")
+ )
+ _check_inference(
+ bb, relax.op.ccl.broadcast_from_worker0(x2), relax.TensorStructInfo((2, 3), "int64")
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py
index ef535bef53..9bce76cecb 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -51,5 +51,26 @@ def test_allreduce():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_broadcast_from_zero():
+ # fmt: off
+ @tvm.script.ir_module
+ class BroadcastFromZero:
+ @R.function
+ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"):
+ gv0: R.Tensor((10, 10), "float32") = R.ccl.broadcast_from_worker0(x)
+ return x
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+ gv0: R.Tensor((10, 10), dtype="float32") = R.call_pure_packed("runtime.disco.broadcast_from_worker0", x, sinfo_args=R.Tensor((10, 10), dtype="float32"))
+ return x
+ # fmt: on
+
+ mod = LegalizeOps()(BroadcastFromZero)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()