MasterJH5574 opened a new pull request, #15373:
URL: https://github.com/apache/tvm/pull/15373

   PR #15327 introduced warp-level primitive support for multi-warp 
allreduce. However, because of the two-stage shuffle-down reduction used in 
the multi-warp case, that PR did not broadcast the allreduce result to every 
reduction thread. This behavior does not match the semantics of allreduce and 
is not ideal for many use cases. This PR therefore completes the 
implementation by inserting a stage that writes the reduction result to 
shared memory, so that every reduction thread across all reduction warps can 
read it.
   
   This shared-memory write-back stage is inserted only in multi-warp 
allreduce cases. In single-warp allreduce, a `shfl_sync` broadcasts the 
reduction result across the reduction threads. Since warp-level primitives 
cannot broadcast a value across warps, shared memory is the only option in 
the multi-warp setting.
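   
   The two-stage scheme plus the new write-back stage can be sketched as a 
plain-Python simulation (this is illustrative only, not the actual TIR 
lowering in the PR; the function names, the zero-padding of partials, and the 
block shape are assumptions made for the sketch):

```python
WARP_SIZE = 32

def warp_shuffle_down_reduce(vals):
    # Simulates an intra-warp shuffle-down tree reduction: at each step,
    # lane i accumulates the value from lane i + offset. After the loop,
    # lane 0 holds the warp's partial sum. Length must be a power of two.
    vals = list(vals)
    offset = len(vals) // 2
    while offset >= 1:
        for lane in range(offset):
            vals[lane] += vals[lane + offset]
        offset //= 2
    return vals[0]

def multi_warp_allreduce(thread_vals):
    # Simulates one thread block; assumes the block size is a multiple
    # of WARP_SIZE and the warp count is at most WARP_SIZE.
    num_warps = len(thread_vals) // WARP_SIZE

    # Stage 1: each warp reduces its own 32 values with shuffle-down;
    # each warp's lane 0 stages its partial sum in shared memory.
    shared = [
        warp_shuffle_down_reduce(thread_vals[w * WARP_SIZE:(w + 1) * WARP_SIZE])
        for w in range(num_warps)
    ]

    # Stage 2: the first warp reduces the per-warp partials (padded with
    # zeros up to the warp width for the power-of-two tree).
    total = warp_shuffle_down_reduce(shared + [0] * (WARP_SIZE - num_warps))

    # Write-back stage added by this PR: without it, only lane 0 of the
    # first warp holds `total`. Storing it to shared memory (followed by a
    # barrier on real hardware) lets every reduction thread read the result.
    shared[0] = total
    return [shared[0] for _ in thread_vals]
```

In real CUDA code, stage 1 and stage 2 would each be a `__shfl_down_sync` 
loop, and the write-back would be a shared-memory store followed by 
`__syncthreads()` before the broadcast load.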
   
   Numerical correctness is verified locally.

