Lunderberg commented on code in PR #16098:
URL: https://github.com/apache/tvm/pull/16098#discussion_r1391679402
##########
python/tvm/relax/op/distributed/distributed.py:
##########
@@ -59,3 +59,28 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) ->
The tensor after redistribution.
"""
return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore
+
+
+def redistribute_replica_to_shard(input: Expr, num_workers: int, axis: int) -> Expr:
Review Comment:
One example: if you are comparing how performance scales with the number
of GPUs, you need a way to specify the number of GPUs. Each data point
collected would be for a specialized value of `num_workers`.
Effectively, having `num_workers: Expr` decouples the communication
mechanics from the choice of how many workers to use. It can still be
explicitly specified at any stage of lowering, and it must still be
statically known after lowering, but we gain flexibility before then. This
has a number of benefits, for example:
* Predicted memory usage. If the number of workers is stored symbolically,
adding up the size of all live values at any point gives the memory footprint
as a function of the number of workers. Requiring the number of workers to be
static at all points of lowering prevents this analysis.
* Consistent optimization. If an optimization is applicable regardless of
the number of workers, it should be applied at a point when the number of
workers is unknown. This prevents a developer from accidentally writing a
less general optimization (e.g. by using a sharded tensor shape in the
pattern-matching).
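To make the memory-usage point concrete, here is a small standalone Python sketch (not TVM's actual API; `sharded_bytes` is a hypothetical helper) showing how keeping the worker count symbolic lets the per-shard memory footprint be expressed as a function of `num_workers`, rather than a single static number:

```python
def sharded_bytes(shape, dtype_bytes, shard_axis):
    """Return a function mapping num_workers -> bytes held by one shard.

    Hypothetical helper for illustration only: the shape is fully static,
    but the worker count is left symbolic (a function parameter), so the
    footprint can be evaluated for any number of workers after the fact.
    """

    def footprint(num_workers):
        elems = 1
        for i, dim in enumerate(shape):
            if i == shard_axis:
                # Ceil-divide the sharded axis across workers.
                elems *= -(-dim // num_workers)
            else:
                elems *= dim
        return elems * dtype_bytes

    return footprint


# A (1024, 1024) float32 tensor sharded along axis 0:
f = sharded_bytes((1024, 1024), 4, shard_axis=0)
print(f(1))  # full tensor on a single worker
print(f(4))  # one quarter of axis 0 per worker
```

If `num_workers` were baked in as a literal `int` at every stage, each evaluation would require re-lowering the whole program; keeping it symbolic makes the footprint a reusable function of the worker count.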
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]