MasterJH5574 opened a new pull request, #15373: URL: https://github.com/apache/tvm/pull/15373
PR #15327 introduced warp-level primitive support for multi-warp allreduce. However, because the multi-warp path uses a two-stage shuffle-down reduction, PR #15327 did not broadcast the allreduce result to every reduction thread. This behavior does not match the semantics of allreduce and is not ideal for many use cases.

This PR completes the implementation by inserting a stage that writes the reduction results to shared memory, so that every reduction thread in every reduction warp can read them. The shared memory write-back stage is inserted only in multi-warp allreduce cases. In single-warp allreduce, a `shfl_sync` broadcasts the reduction results across the reduction threads. In multi-warp settings, warp-level primitives cannot broadcast a value across warps, so we have to go through shared memory.

Numerical correctness was verified locally.
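To make the missing broadcast concrete, here is a minimal CPU-side sketch in Python of a two-stage shuffle-down sum followed by a shared memory write-back. It is only a model under stated assumptions, not the code TVM generates: the helper names (`shfl_down`, `warp_reduce_sum`, `multi_warp_allreduce`), the warp size of 32, the staging buffer layout, and the choice of thread 0 as the writer are all hypothetical and chosen for illustration.

```python
WARP_SIZE = 32  # assumed warp width for this model


def shfl_down(vals, offset):
    """Model a warp shuffle-down: lane i reads the value of lane i + offset.

    Lanes whose source lane is out of range keep their own value.
    """
    n = len(vals)
    return [vals[i + offset] if i + offset < n else vals[i] for i in range(n)]


def warp_reduce_sum(vals):
    """Tree reduction inside one warp. Only lane 0 is guaranteed to hold the full sum."""
    vals = list(vals)
    offset = WARP_SIZE // 2
    while offset > 0:
        shifted = shfl_down(vals, offset)
        vals = [v + s for v, s in zip(vals, shifted)]
        offset //= 2
    return vals


def multi_warp_allreduce(values, broadcast=True):
    """Return the value each thread would see after the reduction (hypothetical model)."""
    assert len(values) % WARP_SIZE == 0
    num_warps = len(values) // WARP_SIZE
    assert num_warps <= WARP_SIZE

    # Stage 1: each warp reduces its own lanes; lane 0 of each warp
    # stores the partial sum into a staging buffer (models shared memory).
    staging = [0] * WARP_SIZE
    for w in range(num_warps):
        warp_vals = values[w * WARP_SIZE:(w + 1) * WARP_SIZE]
        staging[w] = warp_reduce_sum(warp_vals)[0]

    # Stage 2: one warp reduces the staged partial sums (unused slots are 0).
    total = warp_reduce_sum(staging)[0]

    # Without a broadcast, only one thread holds the final result.
    per_thread = [None] * len(values)
    per_thread[0] = total

    if broadcast:
        # Write-back stage: the result goes to shared memory and, after a
        # barrier, every thread in every warp reads it back.
        shared_result = total
        per_thread = [shared_result] * len(values)
    return per_thread


values = list(range(4 * WARP_SIZE))  # 4 warps, 128 threads
with_broadcast = multi_warp_allreduce(values, broadcast=True)
without_broadcast = multi_warp_allreduce(values, broadcast=False)
```

In this model, `without_broadcast` holds the sum only in slot 0, which mirrors the behavior described for PR #15327, while `with_broadcast` holds the sum in every slot, which is what allreduce semantics require.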
