Lunderberg commented on code in PR #16098:
URL: https://github.com/apache/tvm/pull/16098#discussion_r1396292028
##########
python/tvm/relax/op/distributed/distributed.py:
##########
@@ -59,3 +59,28 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) ->
The tensor after redistribution.
"""
return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore
+
+
+def redistribute_replica_to_shard(input: Expr, num_workers: int, axis: int) -> Expr:
Review Comment:
The main difference arises when additional optimization steps run after a
module is defined but before the module is handed off to `relax.build` for
lowering/compilation.
If sharding (and its propagation) is done early in optimization, then
propagation produces large portions of the compute graph in which no
communication steps are required. These portions can be optimized as if
they were single-GPU modules, with no multi-GPU-specific handling
required.
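As a hypothetical illustration (not TVM code, and the op names are made up): after sharding propagation, the op sequence decomposes into communication-free segments separated by collective ops, and each segment can then be optimized like an ordinary single-GPU module.

```python
# Toy model: a flat op sequence where "all_reduce" marks the only
# communication points. Splitting at those points yields segments that
# contain no cross-GPU communication at all.
ops = ["matmul", "relu", "all_reduce", "matmul", "add", "all_reduce", "softmax"]

segments, current = [], []
for op in ops:
    if op == "all_reduce":  # communication boundary
        segments.append(current)
        current = []
    else:
        current.append(op)
segments.append(current)

# segments == [["matmul", "relu"], ["matmul", "add"], ["softmax"]]
# Each segment could be handed to single-GPU optimizations unchanged.
```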
If we can write modules with a dynamic number of GPUs, we can write the
lowering steps as a single optimization pipeline.
```python
# With specialization occurring late in the pipeline.
mod = Sequential([
pre_sharding_optimizations,
shard_across_multiple_gpus,
propagate_sharding,
convert_to_local_view,
single_gpu_optimizations,
])(mod)
built_modules = [relax.build(specialize(mod, num_gpus)) for num_gpus in num_gpu_list]
```
If we can only write modules with a static number of GPUs, we cannot express
the lowering as a single optimization pipeline: specialization must happen up
front, and every subsequent pass must be applied to each specialized module
separately.
```python
# With specialization occurring at the start of the pipeline.
mod = pre_sharding_optimizations(mod)
mods = [shard_across_multiple_gpus(mod, num_gpus) for num_gpus in num_gpu_list]
pipeline = Sequential([
propagate_sharding,
convert_to_local_view,
single_gpu_optimizations,
])
mods = [pipeline(mod) for mod in mods]
built_modules = [relax.build(mod) for mod in mods]
```
It's not that this is impossible by any means, but the restricted
expressibility of an early specialization step means that a user must leave
the world of a single `IRModule` much earlier.
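The structural difference between the two pipelines above can be sketched with plain Python, no TVM required. Every "pass" here is a hypothetical stub that just records what it did on a dict standing in for an `IRModule`; the point is only that with a symbolic worker count, one pipeline object is built and run once, and specialization is the final fan-out step.

```python
# Hypothetical stub passes: each one tags the module dict. The sharding
# pass leaves the worker count symbolic ("N") instead of a concrete int.
def pre_sharding_optimizations(mod):
    return {**mod, "optimized": True}

def shard_across_multiple_gpus(mod):
    return {**mod, "sharded": "N"}  # symbolic number of workers

def propagate_sharding(mod):
    return {**mod, "propagated": True}

def convert_to_local_view(mod):
    return {**mod, "local_view": True}

def single_gpu_optimizations(mod):
    return {**mod, "single_gpu_opt": True}

def sequential(passes):
    """Compose passes into a single pipeline, like Sequential above."""
    def run(mod):
        for p in passes:
            mod = p(mod)
        return mod
    return run

def specialize(mod, num_gpus):
    """Substitute the symbolic worker count with a concrete one."""
    return {**mod, "sharded": num_gpus}

# Late specialization: the whole pipeline runs exactly once; only the
# final specialize step fans out per GPU count.
pipeline = sequential([
    pre_sharding_optimizations,
    shard_across_multiple_gpus,
    propagate_sharding,
    convert_to_local_view,
    single_gpu_optimizations,
])
mod = pipeline({"name": "main"})
built = [specialize(mod, n) for n in (2, 4, 8)]
```

With a static worker count, `specialize` would have to run right after `shard_across_multiple_gpus`, and everything downstream of it would execute once per entry in the GPU-count list instead of once total.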
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]