jinhongyii opened a new pull request, #16092:
URL: https://github.com/apache/tvm/pull/16092
This PR adds an `axis` field to `scatter_from_worker0` that specifies the tensor
axis along which the input is scattered. `legalize_ops` automatically inserts a
reshape and a transpose to satisfy the ccl constraint that collective
communication ops must operate on contiguous memory. For example, if `x` has
shape [10, 20] and we call `scatter_from_worker0(x, num_workers=2, axis=1)`,
then after legalization it expands to
```
x = reshape(x, [10, 2, 10]) # shape: [10, 2, 10]
x = permute_dims(x, [1, 0, 2]) # shape: [2, 10, 10]
x = call_dps_packed("scatter_from_worker0", x) # shape on each worker: [10, 10]
```
When axis=0, the behavior is the same as before.
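
The shape bookkeeping behind this expansion can be checked without TVM. The pure-Python sketch below is an illustration only, not the PR's implementation. `legalize_shapes`, `transpose_flat`, and `legalized_scatter` are hypothetical helpers, and tensors are modeled as flat row-major Python lists. The sketch computes the reshape target and permutation for an arbitrary `axis`, materializes the transpose, and splits the result into contiguous per-worker chunks.

```python
from math import prod


def legalize_shapes(shape, axis, num_workers):
    """Hypothetical helper: return (reshape target, permutation, per-worker shape)."""
    if shape[axis] % num_workers != 0:
        raise ValueError("shape[axis] must be divisible by num_workers")
    chunk = shape[axis] // num_workers
    # Split `axis` into (num_workers, chunk); the worker dim lands at index `axis`.
    reshaped = list(shape[:axis]) + [num_workers, chunk] + list(shape[axis + 1:])
    # Move the worker dim to the front so each worker's data becomes contiguous.
    perm = [axis] + [i for i in range(len(reshaped)) if i != axis]
    per_worker = list(shape[:axis]) + [chunk] + list(shape[axis + 1:])
    return reshaped, perm, per_worker


def row_major_strides(shape):
    """Element strides of a C-contiguous (row-major) tensor."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return strides[::-1]


def transpose_flat(flat, shape, perm):
    """Materialize a NumPy-style transpose of a flat row-major buffer."""
    in_strides = row_major_strides(shape)
    out_shape = [shape[p] for p in perm]
    out_strides = row_major_strides(out_shape)
    out = []
    for idx in range(prod(out_shape)):
        # Output coordinate j = (idx // out_strides[j]) % out_shape[j] indexes input axis perm[j].
        src = sum(((idx // s) % d) * in_strides[p]
                  for s, d, p in zip(out_strides, out_shape, perm))
        out.append(flat[src])
    return out, out_shape


def legalized_scatter(flat, shape, axis, num_workers):
    """Simulate reshape -> transpose -> split of a contiguous buffer across workers."""
    reshaped, perm, per_worker = legalize_shapes(shape, axis, num_workers)
    # Reshaping a contiguous buffer changes only shape metadata, not the data.
    buf, _ = transpose_flat(flat, reshaped, perm)
    n = prod(per_worker)
    # After the transpose, worker w owns the w-th contiguous chunk of n elements.
    return [buf[w * n:(w + 1) * n] for w in range(num_workers)], per_worker


# Usage: the [10, 20] example above, scattered along axis=1.
x = list(range(10 * 20))
print(legalize_shapes([10, 20], axis=1, num_workers=2))  # ([10, 2, 10], [1, 0, 2], [10, 10])
chunks, per_worker = legalized_scatter(x, [10, 20], axis=1, num_workers=2)
# Worker w's chunk equals the column slice x[:, 10*w : 10*(w+1)].
```

With `axis=0` the computed permutation is the identity (`[0, 1, 2]` for this shape). The transpose then leaves the buffer unchanged, and each worker simply receives a contiguous block of rows. This is consistent with the unchanged `axis=0` behavior.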

This PR also renames `ScatterFromWorker0Attrs` to `ScatterAttrs` so that other
ops, such as worker-id-aware slicing, can reuse it (`scatter_from_worker0` is
equivalent to `broadcast_from_worker0` followed by worker-id-aware slicing).
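
The decomposition in the parenthetical above can be illustrated with a small pure-Python sketch. `broadcast_from_worker0` and `worker_slice` below are hypothetical stand-ins, not TVM's ops or the `ScatterAttrs` API. They model tensors as 2-D nested lists and workers as list positions, and they only support `axis` in {0, 1}.

```python
def broadcast_from_worker0(x, num_workers):
    """Hypothetical stand-in: every worker receives a full copy of worker 0's tensor."""
    return [[row[:] for row in x] for _ in range(num_workers)]


def worker_slice(x, axis, worker_id, num_workers):
    """Hypothetical worker-id-aware slicing of a 2-D nested list along `axis`."""
    if axis not in (0, 1):
        raise ValueError("this sketch only handles 2-D tensors (axis 0 or 1)")
    extent = len(x) if axis == 0 else len(x[0])
    if extent % num_workers != 0:
        raise ValueError("extent along axis must be divisible by num_workers")
    chunk = extent // num_workers
    lo, hi = worker_id * chunk, (worker_id + 1) * chunk
    if axis == 0:
        return [row[:] for row in x[lo:hi]]
    return [row[lo:hi] for row in x]


def scatter_via_broadcast(x, axis, num_workers):
    """Model scatter_from_worker0 as a broadcast followed by per-worker slicing."""
    copies = broadcast_from_worker0(x, num_workers)
    return [worker_slice(copies[w], axis, w, num_workers) for w in range(num_workers)]


# Usage: x has shape [10, 20]; scatter along axis=1 across 2 workers.
x = [[r * 20 + c for c in range(20)] for r in range(10)]
parts = scatter_via_broadcast(x, axis=1, num_workers=2)
print(len(parts[0]), len(parts[0][0]))  # 10 10
```

The two formulations agree in result but not in communication volume. A broadcast sends the whole tensor to every worker, whereas a direct scatter sends each worker only its 1/num_workers share.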