Lunderberg commented on code in PR #15764:
URL: https://github.com/apache/tvm/pull/15764#discussion_r1332182751
##########
src/relax/op/ccl/ccl.cc:
##########
@@ -50,6 +50,44 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.ccl.allgather */
+TVM_REGISTER_NODE_TYPE(AllGatherAttrs);
+
+Expr allgather(Expr x, int num_workers) {
Review Comment:
Here, this would be `Expr allgather(Expr x, Expr num_workers)`.
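A rough sketch of the constructor under that signature (untested; assumes `num_workers` arrives as a call operand, e.g. a `relax.PrimValue`, which would make `AllGatherAttrs` unnecessary):
```c++
// Sketch only: num_workers becomes a second call operand
// instead of an attribute, so no AllGatherAttrs is needed.
Expr allgather(Expr x, Expr num_workers) {
  static const Op& op = Op::Get("relax.ccl.allgather");
  return Call(op, {std::move(x), std::move(num_workers)}, Attrs(), {});
}
```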
##########
src/runtime/disco/builtin.h:
##########
@@ -52,6 +52,12 @@ NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device);

* \return The outcome of allreduce
*/
void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
+/*!
+ * \brief Perform an allgather operation using the underlying communication library
+ * \param send The array send to perform allgather on
+ * \return The outcome of allgather
+ */
+void AllGather(NDArray send, Optional<NDArray> recv);
Review Comment:
The signature here doesn't match the signature of the implementation. In
the implementation, the second parameter is `NDArray recv` and not
`Optional<NDArray> recv`.
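For clarity, the declaration would then read (sketch, matching the definition in `builtin.cc`):
```c++
// Declaration matching the implementation's signature in builtin.cc.
void AllGather(NDArray send, NDArray recv);
```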
##########
src/relax/op/ccl/ccl.cc:
##########
@@ -50,6 +50,44 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.ccl.allgather */
+TVM_REGISTER_NODE_TYPE(AllGatherAttrs);
+
+Expr allgather(Expr x, int num_workers) {
+ ObjectPtr<AllGatherAttrs> attrs = make_object<AllGatherAttrs>();
+ attrs->num_workers = std::move(num_workers);
+ static const Op& op = Op::Get("relax.ccl.allgather");
+ return Call(op, {std::move(x)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather);
+
+StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ DataType output_dtype = input_sinfo->dtype;
+
+ const auto* attrs = call->attrs.as<AllGatherAttrs>();
+ int num_workers = attrs->num_workers;
+ auto input_shape = input_sinfo->GetShape();
+ if (!input_shape.defined()) {
+ return input_sinfo;
+ }
+ Array<PrimExpr> output_shape = input_shape.value();
+ output_shape.Set(0, floor(output_shape[0] * num_workers));
+ if (input_sinfo->vdevice.defined()) {
Review Comment:
Instead of duplicating the entire `return TensorStructInfo(...)` call, we
can make a local variable. Otherwise, readers would need to inspect both
calls to verify that they differ only in the `vdevice` argument.
```c++
VDevice vdevice;
if (input_sinfo->vdevice.defined()) {
  vdevice = input_sinfo->vdevice.value();
}
return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdevice);
```
##########
python/tvm/relax/op/ccl/ccl.py:
##########
@@ -44,6 +44,25 @@ def allreduce(x, op_type: str = "sum"):  # pylint: disable=invalid-name
     return _ffi_api.allreduce(x, op_type)  # type: ignore  # pylint: disable=no-member
+def allgather(x, num_workers: int): # pylint: disable=invalid-name
Review Comment:
I like the inclusion of the `num_workers` here, as it allows shape
propagation across the `allgather`. Instead of an integer, can we make it be a
`relax.PrimValue`? That would (1) allow the `num_workers` to be a symbolic
variable for earlier stages of lowering, (2) play nicer with the relax pattern
matcher, and (3) remove the need for the `AllGatherAttrs`.
```python
def allgather(x, num_workers: Union[int, PrimExpr, relax.PrimValue]):
if not isinstance(num_workers, relax.PrimValue):
num_workers = relax.PrimValue(num_workers)
...
```
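For illustration, the full wrapper under that suggestion might look like this (untested sketch; it assumes the `_ffi_api.allgather` binding is updated to accept the `PrimValue` operand):
```python
# Sketch only: normalize num_workers to a relax.PrimValue before dispatch.
def allgather(x, num_workers: Union[int, PrimExpr, relax.PrimValue]):
    if not isinstance(num_workers, relax.PrimValue):
        num_workers = relax.PrimValue(num_workers)
    return _ffi_api.allgather(x, num_workers)  # type: ignore
```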
##########
src/relax/op/ccl/ccl.cc:
##########
@@ -50,6 +50,44 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.ccl.allgather */
+TVM_REGISTER_NODE_TYPE(AllGatherAttrs);
+
+Expr allgather(Expr x, int num_workers) {
+ ObjectPtr<AllGatherAttrs> attrs = make_object<AllGatherAttrs>();
+ attrs->num_workers = std::move(num_workers);
+ static const Op& op = Op::Get("relax.ccl.allgather");
+ return Call(op, {std::move(x)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather);
+
+StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
Review Comment:
With `num_workers` as a `PrimValue`, this would instead be
```c++
CHECK_EQ(call->args.size(), 2);
auto input_sinfo = Downcast<TensorStructInfo>(call->args[0]->struct_info_);
auto num_workers_sinfo =
Downcast<PrimStructInfo>(call->args[1]->struct_info_);
auto num_workers = num_workers_sinfo->value;
```
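With `num_workers` available as a (possibly symbolic) `PrimExpr`, the output-shape update could stay symbolic as well (untested sketch):
```c++
// Sketch only: scale the leading dimension by the symbolic num_workers,
// with no floor() needed since the expression stays in PrimExpr form.
Array<PrimExpr> output_shape = input_shape.value();
output_shape.Set(0, output_shape[0] * num_workers);
```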
##########
src/runtime/disco/builtin.cc:
##########
@@ -84,6 +84,8 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), recv);
}
+void AllGathere(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); }
Review Comment:
Typo: This should be `AllGather` instead of `AllGathere`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]