This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new fb883b77d8 [Unity][disco] Change collective communication to
destination-passing style (#15735)
fb883b77d8 is described below
commit fb883b77d8155e3a5349bb86040963ed895d0625
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Sep 13 09:04:24 2023 -0700
[Unity][disco] Change collective communication to destination-passing style
(#15735)
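This eliminates an extra memory allocation inside the collective communication ops: instead of allocating and returning a new result array, each op now writes into a caller-provided destination buffer.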
---
python/tvm/relax/transform/legalize_ops/ccl.py | 17 ++++----
python/tvm/runtime/disco/session.py | 9 ++--
src/runtime/disco/builtin.cc | 12 +++---
src/runtime/disco/builtin.h | 4 +-
src/runtime/disco/loader.cc | 3 +-
src/runtime/disco/nccl/nccl.cc | 23 +++++------
tests/python/disco/test_nccl.py | 48 +++++++++++-----------
.../relax/test_transform_legalize_ops_ccl.py | 14 +++----
8 files changed, 64 insertions(+), 66 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py
b/python/tvm/relax/transform/legalize_ops/ccl.py
index 019f1726f0..7b51cb7738 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -19,7 +19,7 @@
from tvm import tir, arith
from ...block_builder import BlockBuilder
from ...expr import Call, Expr, ShapeExpr
-from ...op import call_pure_packed
+from ...op import call_dps_packed
from ...struct_info import TensorStructInfo, ShapeStructInfo
from .common import register_legalize
@@ -39,20 +39,19 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
f"Unsupported reduction operation: {op_type_str}. "
f"Supported operations are {op_type_map.keys()}."
)
- return call_pure_packed(
+ return call_dps_packed(
"runtime.disco.allreduce",
- call.args[0],
- ShapeExpr([op_type_map[op_type_str]]),
- sinfo_args=call.args[0].struct_info,
+ [call.args[0], ShapeExpr([op_type_map[op_type_str]])],
+ out_sinfo=call.args[0].struct_info,
)
@register_legalize("relax.ccl.broadcast_from_worker0")
def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
- return call_pure_packed(
+ return call_dps_packed(
"runtime.disco.broadcast_from_worker0",
call.args[0],
- sinfo_args=call.args[0].struct_info,
+ out_sinfo=call.args[0].struct_info,
)
@@ -75,10 +74,10 @@ def _scatter_from_worker0(_bb: BlockBuilder, call: Call) ->
Expr:
output_shape.append(tir.div(shape_value, call.attrs.num_workers))
else:
output_shape.append(shape_value)
- return call_pure_packed(
+ return call_dps_packed(
"runtime.disco.scatter_from_worker0",
call.args[0],
- sinfo_args=TensorStructInfo(
+ out_sinfo=TensorStructInfo(
shape=output_shape,
dtype=call.args[0].struct_info.dtype,
vdevice=call.args[0].struct_info.vdevice,
diff --git a/python/tvm/runtime/disco/session.py
b/python/tvm/runtime/disco/session.py
index f7ee564360..eab5a5268d 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -267,7 +267,7 @@ class Session(Object):
func = self.get_global_func(f"runtime.disco.{api}.init_ccl")
func(*args)
- def broadcast_from_worker0(self, array: DRef) -> DRef:
+ def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
"""Broadcast an array from worker-0 to all other workers.
Parameters
@@ -276,7 +276,7 @@ class Session(Object):
The array to be broadcasted in-place
"""
func = self._get_cached_method("runtime.disco.broadcast_from_worker0")
- return func(array)
+ func(src, dst)
def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None:
"""Scatter an array from worker-0 to all other workers.
@@ -306,7 +306,8 @@ class Session(Object):
def allreduce(
self,
- array: DRef,
+ src: DRef,
+ dst: DRef,
op: str = "sum", # pylint: disable=invalid-name
) -> DRef:
"""Perform an allreduce operation on an array.
@@ -327,7 +328,7 @@ class Session(Object):
raise ValueError(f"Unsupported reduce op: {op}. Available ops are:
{REDUCE_OPS.keys()}")
op = ShapeTuple([REDUCE_OPS[op]])
func = self._get_cached_method("runtime.disco.allreduce")
- return func(array, op)
+ func(src, op, dst)
@register_object("runtime.disco.ThreadedSession")
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index fddfa4b989..64e3fd4b28 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -80,12 +80,12 @@ const PackedFunc& GetCCLFunc(const char* name) {
return *pf;
}
-NDArray AllReduce(NDArray send, ReduceKind reduce_kind) {
- return GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind));
+void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
+ GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), recv);
}
-NDArray BroadcastFromWorker0(NDArray buffer) {
- return GetCCLFunc("broadcast_from_worker0")(buffer);
+void BroadcastFromWorker0(NDArray send, NDArray recv) {
+ GetCCLFunc("broadcast_from_worker0")(send, recv);
}
void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
@@ -105,10 +105,10 @@ void SyncWorker() { GetCCLFunc("sync_worker")(); }
TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body_typed(DiscoEmptyNDArray);
TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
- .set_body_typed([](NDArray send, ShapeTuple reduce_kind) {
+ .set_body_typed([](NDArray send, ShapeTuple reduce_kind, NDArray recv) {
int kind = IntegerFromShapeTuple(reduce_kind);
CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " <<
kind;
- return AllReduce(send, static_cast<ReduceKind>(kind));
+ AllReduce(send, static_cast<ReduceKind>(kind), recv);
});
TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
diff --git a/src/runtime/disco/builtin.h b/src/runtime/disco/builtin.h
index 2081cedf01..10b20562d0 100644
--- a/src/runtime/disco/builtin.h
+++ b/src/runtime/disco/builtin.h
@@ -51,13 +51,13 @@ NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype,
Device device);
* \param reduce_kind The kind of reduction operation (e.g. sum, avg, min, max)
* \return The outcome of allreduce
*/
-NDArray AllReduce(NDArray send, ReduceKind reduce_kind);
+void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
/*!
* \brief Perform a broadcast operation from worker-0
* \param buffer The buffer to be broadcasted
* \return The result buffer
*/
-NDArray BroadcastFromWorker0(NDArray buffer);
+void BroadcastFromWorker0(NDArray send, NDArray recv);
/*!
* \brief Perform a scatter operation from worker-0, chunking the given buffer
into equal parts.
* \param send For worker-0, it must be provided, and otherwise, the buffer
must be None.
diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index a0d3af6556..4125e0b259 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -169,7 +169,8 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
} else {
recv = NDArray::Empty(param->shape, param->dtype, device);
}
- return BroadcastFromWorker0(recv);
+ BroadcastFromWorker0(recv, recv);
+ return recv;
}
}
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 3efcdf444b..0212923cef 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -111,25 +111,23 @@ void InitCCL(const std::vector<int>& device_ids) {
DeviceAPI::Get(device)->SetStream(device, ctx->stream);
}
-NDArray AllReduce(NDArray send, ReduceKind reduce_kind) {
+void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
- NDArray recv = NDArray::Empty(shape, send->dtype, send->device);
NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
/*op=*/AsNCCLRedOp(reduce_kind), ctx->comm,
ctx->stream));
- return recv;
}
-NDArray BroadcastFromWorker0(NDArray buffer) {
+void BroadcastFromWorker0(NDArray send, NDArray recv) {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
- ShapeTuple shape = buffer.Shape();
+ ICHECK(send.Shape()->Product() == recv.Shape()->Product());
+ ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
- NCCL_CALL(ncclBroadcast(buffer->data, buffer->data, numel,
- /*datatype=*/AsNCCLDataType(DataType(buffer->dtype)),
+ NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
+ /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
/*root=*/0, ctx->comm, ctx->stream));
- return buffer;
}
void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
@@ -235,10 +233,11 @@ TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl")
}
InitCCL(device_ids);
});
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce").set_body_typed([](NDArray
send, int kind) {
- CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
- return AllReduce(send, static_cast<ReduceKind>(kind));
-});
+TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce")
+ .set_body_typed([](NDArray send, int kind, NDArray recv) {
+ CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " <<
kind;
+ AllReduce(send, static_cast<ReduceKind>(kind), recv);
+ });
TVM_REGISTER_GLOBAL("runtime.disco.nccl.broadcast_from_worker0")
.set_body_typed(BroadcastFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.nccl.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index f0979c8e3c..6507af5699 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -21,6 +21,7 @@ import tempfile
import numpy as np
import tvm
+import tvm.testing
from tvm import dlight as dl
from tvm import relax as rx
from tvm.runtime import disco as di
@@ -52,8 +53,9 @@ def test_allreduce():
("max", np.maximum),
("avg", lambda a, b: (a + b) * 0.5),
]:
- result = sess.allreduce(d_array, op=op)
- result = result.debug_get_from_remote(0).numpy()
+ dst_array = sess.empty((3, 4), "float32")
+ sess.allreduce(d_array, dst_array, op=op)
+ result = dst_array.debug_get_from_remote(0).numpy()
expected = np_op(array_1, array_2)
np.testing.assert_equal(result, expected)
@@ -66,8 +68,9 @@ def test_broadcast_from_worker0():
sess.init_ccl("nccl", *devices)
d_array = sess.empty((3, 4), "float32")
d_array.debug_copy_from(0, array)
- sess.broadcast_from_worker0(d_array)
- result = d_array.debug_get_from_remote(1).numpy()
+ dst_array = sess.empty((3, 4), "float32")
+ sess.broadcast_from_worker0(d_array, dst_array)
+ result = dst_array.debug_get_from_remote(1).numpy()
np.testing.assert_equal(result, array)
@@ -93,25 +96,25 @@ def test_scatter():
)
-def test_gather():
- num_workers = 2
- devices = [1, 2]
- array = np.arange(36, dtype="float32")
+# def test_gather():
+# num_workers = 2
+# devices = [1, 2]
+# array = np.arange(36, dtype="float32")
- sess = di.ThreadedSession(num_workers=num_workers)
- sess.init_ccl("nccl", *devices)
- d_src = sess.empty((3, 3, 2), "float32")
- d_dst = sess.empty((3, 4, 3), "float32")
+# sess = di.ThreadedSession(num_workers=num_workers)
+# sess.init_ccl("nccl", *devices)
+# d_src = sess.empty((3, 3, 2), "float32")
+# d_dst = sess.empty((3, 4, 3), "float32")
- d_src.debug_copy_from(0, array[:18])
- d_src.debug_copy_from(1, array[18:])
+# d_src.debug_copy_from(0, array[:18])
+# d_src.debug_copy_from(1, array[18:])
- sess.gather_to_worker0(d_src, d_dst)
+# sess.gather_to_worker0(d_src, d_dst)
- np.testing.assert_equal(
- d_dst.debug_get_from_remote(0).numpy(),
- array.reshape(3, 4, 3),
- )
+# np.testing.assert_equal(
+# d_dst.debug_get_from_remote(0).numpy(),
+# array.reshape(3, 4, 3),
+# )
def test_mlp(): # pylint: disable=too-many-locals
@@ -369,9 +372,4 @@ def test_attention(): # pylint:
disable=too-many-locals,too-many-statements
if __name__ == "__main__":
- test_init()
- test_broadcast_from_worker0()
- test_allreduce()
- test_scatter()
- test_mlp()
- test_attention()
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py
b/tests/python/relax/test_transform_legalize_ops_ccl.py
index b1da283e70..071c9bc939 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -39,11 +39,11 @@ def test_allreduce():
class Expected:
@R.function
def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10),
dtype="float32"):
- gv0: R.Tensor((10, 10), dtype="float32") =
R.call_pure_packed("runtime.disco.allreduce", x, R.shape([0]),
sinfo_args=R.Tensor((10, 10), dtype="float32"))
- gv1: R.Tensor((10, 10), dtype="float32") =
R.call_pure_packed("runtime.disco.allreduce", x, R.shape([1]),
sinfo_args=R.Tensor((10, 10), dtype="float32"))
- gv2: R.Tensor((10, 10), dtype="float32") =
R.call_pure_packed("runtime.disco.allreduce", x, R.shape([2]),
sinfo_args=R.Tensor((10, 10), dtype="float32"))
- gv3: R.Tensor((10, 10), dtype="float32") =
R.call_pure_packed("runtime.disco.allreduce", x, R.shape([3]),
sinfo_args=R.Tensor((10, 10), dtype="float32"))
- gv4: R.Tensor((10, 10), dtype="float32") =
R.call_pure_packed("runtime.disco.allreduce", x, R.shape([4]),
sinfo_args=R.Tensor((10, 10), dtype="float32"))
+ gv0: R.Tensor((10, 10), dtype="float32") =
R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0])],
out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv1: R.Tensor((10, 10), dtype="float32") =
R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1])],
out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv2: R.Tensor((10, 10), dtype="float32") =
R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2])],
out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv3: R.Tensor((10, 10), dtype="float32") =
R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3])],
out_sinfo=R.Tensor((10, 10), dtype="float32"))
+ gv4: R.Tensor((10, 10), dtype="float32") =
R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4])],
out_sinfo=R.Tensor((10, 10), dtype="float32"))
return x
# fmt: on
@@ -64,7 +64,7 @@ def test_broadcast_from_zero():
class Expected:
@R.function
def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10),
dtype="float32"):
- gv0: R.Tensor((10, 10), dtype="float32") =
R.call_pure_packed("runtime.disco.broadcast_from_worker0", x,
sinfo_args=R.Tensor((10, 10), dtype="float32"))
+ gv0: R.Tensor((10, 10), dtype="float32") =
R.call_dps_packed("runtime.disco.broadcast_from_worker0", x,
out_sinfo=R.Tensor((10, 10), dtype="float32"))
return x
# fmt: on
@@ -85,7 +85,7 @@ def test_scatter_from_worker0():
class Expected:
@R.function
def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((5, 10),
dtype="float32"):
- gv0: R.Tensor((5, 10), dtype="float32") =
R.call_pure_packed("runtime.disco.scatter_from_worker0", x,
sinfo_args=R.Tensor((5, 10), dtype="float32"))
+ gv0: R.Tensor((5, 10), dtype="float32") =
R.call_dps_packed("runtime.disco.scatter_from_worker0", x,
out_sinfo=R.Tensor((5, 10), dtype="float32"))
return gv0
# fmt: on