jinhongyii opened a new pull request, #16092:
URL: https://github.com/apache/tvm/pull/16092

   This PR adds an `axis` field to `scatter_from_worker0` that specifies the tensor
   axis along which the tensor is scattered. `legalize_ops` automatically inserts
   `reshape` and transpose (`permute_dims`) ops to satisfy the CCL constraint that
   collective communication ops must operate on contiguous memory. For example, if
   `x` has shape [10, 20] and we call `scatter_from_worker0(x, num_workers=2, axis=1)`,
   legalization expands it to:
   
   ```
   x = reshape(x, [10, 2, 10]) # shape: [10, 2, 10]
   x = permute_dims(x, [1, 0, 2]) # shape: [2, 10, 10]
   x = call_dps_packed("scatter_from_worker0", x) # shape on each worker: [10, 10]
   ```
   When `axis=0`, the behavior is unchanged from before.
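The following is a minimal pure-Python model of the legalized lowering, written to show why the transpose is needed. It assumes a tensor is a flat row-major (C-contiguous) list and that the CCL scatter simply hands worker `w` the `w`-th contiguous chunk of the buffer. The helper names (`row_major_strides`, `permute_dims`, `scatter_contiguous`, `legalized_scatter`) are hypothetical illustrations, not TVM APIs. After `reshape` and `permute_dims`, each contiguous chunk is exactly one column slice of the original tensor. A naive contiguous split of the untransposed buffer would instead split by rows.

```python
from itertools import product


def row_major_strides(shape):
    # Element strides of a C-contiguous tensor with the given shape.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def permute_dims(buf, shape, axes):
    # Hypothetical helper: materialize the transposed tensor as a new contiguous buffer.
    src_strides = row_major_strides(shape)
    new_shape = [shape[a] for a in axes]
    out = []
    # product() yields indices in row-major order of new_shape.
    for idx in product(*(range(n) for n in new_shape)):
        offset = sum(i * src_strides[a] for i, a in zip(idx, axes))
        out.append(buf[offset])
    return out, new_shape


def scatter_contiguous(buf, num_workers):
    # Assumed model of the CCL scatter: worker w receives the w-th contiguous chunk.
    chunk = len(buf) // num_workers
    return [buf[w * chunk:(w + 1) * chunk] for w in range(num_workers)]


def legalized_scatter(buf, shape, num_workers, axis):
    shape = list(shape)
    assert shape[axis] % num_workers == 0
    # reshape: split `axis` into (num_workers, shape[axis] // num_workers).
    # Reshaping a contiguous buffer moves no data, so `buf` is reused as-is.
    new_shape = shape[:axis] + [num_workers, shape[axis] // num_workers] + shape[axis + 1:]
    # permute_dims: bring the num_workers dimension to the front.
    axes = [axis] + [d for d in range(len(new_shape)) if d != axis]
    buf, _ = permute_dims(buf, new_shape, axes)
    return scatter_contiguous(buf, num_workers)


# x has shape [10, 20]; element x[r][c] is stored at offset r * 20 + c.
x = list(range(10 * 20))
chunks = legalized_scatter(x, [10, 20], num_workers=2, axis=1)
print(len(chunks), len(chunks[0]))      # 2 100, i.e. [10, 10] per worker
print(chunks[1][:3])                    # [10, 11, 12]: row 0, columns 10..12
print(scatter_contiguous(x, 2)[1][:3])  # [100, 101, 102]: naive split is by row
```

For `axis=1` this model derives the same `[10, 2, 10]` reshape and `[1, 0, 2]` permutation as the expansion in the PR. For `axis=0` the permutation is the identity, which matches the unchanged behavior.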
   
   
   This PR also renames `ScatterFromWorker0Attrs` to `ScatterAttrs` so the attrs can
   be reused by other ops such as worker-id-aware slicing (`scatter_from_worker0` is
   equivalent to `broadcast_from_worker0` followed by worker-id-aware slicing).
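The decomposition above can be modeled in a few lines for the 2-D case. This is a sketch of the *result* semantics only, with nested Python lists standing in for tensors. The function names mirror the op names in the PR, but `worker_id_aware_slice` is a hypothetical name, and none of these functions are TVM's implementations. The equivalence holds for the values each worker ends up with, not for communication cost: a broadcast ships the full tensor to every worker, while a scatter ships only each worker's share.

```python
def broadcast_from_worker0(x, num_workers):
    # Every worker receives a full copy of worker 0's tensor.
    return [[row[:] for row in x] for _ in range(num_workers)]


def worker_id_aware_slice(x, worker_id, num_workers, axis):
    # Hypothetical op: keep only this worker's 1/num_workers share of `axis` (2-D only).
    assert axis in (0, 1)
    n = len(x) if axis == 0 else len(x[0])
    assert n % num_workers == 0
    size = n // num_workers
    lo, hi = worker_id * size, (worker_id + 1) * size
    if axis == 0:
        return [row[:] for row in x[lo:hi]]
    return [row[lo:hi] for row in x]


def scatter_from_worker0(x, num_workers, axis):
    # scatter_from_worker0 = broadcast_from_worker0 + worker-id-aware slicing.
    copies = broadcast_from_worker0(x, num_workers)
    return [worker_id_aware_slice(copies[w], w, num_workers, axis)
            for w in range(num_workers)]


# x has shape [10, 20] with x[r][c] = r * 20 + c.
x = [[r * 20 + c for c in range(20)] for r in range(10)]
parts = scatter_from_worker0(x, num_workers=2, axis=1)
print(len(parts[1]), len(parts[1][0]))  # 10 10
print(parts[1][0][:3])                  # [10, 11, 12]
```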
   

