Lunderberg commented on code in PR #16094:
URL: https://github.com/apache/tvm/pull/16094#discussion_r1393282812
##########
tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py:
##########
@@ -1060,16 +1738,19 @@ def test_mlp_pipeline_parallelism():
 def test_decoder_layer():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayer
-    after = relax.distributed.transform.PropagateSharding()(mod)
+    after = relax.distributed.transform.PropagateSharding()(LlamaAttentionLayer)
     assert_structural_equal(after, ShardedLlamaAttentionLayer)


-def test_decoder_layer_dynamic_shape():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayerDynamicShape
+def test_decoder_layer_tir():
+    mod = relax.transform.LegalizeOps()(LlamaAttentionLayer)
Review Comment:
> It's an interesting topic to allow the split axis to be dynamic, but currently we just enforce that the axis has to be static for simplicity.

~~I was thinking that it would be the other way around. In order to correctly split a `PrimFunc`, the function must contain enough information to produce a smaller version of itself. That would either require the split axis to be dynamic, so size changes are handled automatically, or require undoing any constant-folding that has been done using the split size.~~

Edit: Disregard that. Any function could always be split up into a PrimFunc
for each GPU, along with one to merge the final results. Whether those
individual PrimFuncs are identical for each GPU, and whether the merging could
be expressed as an allgather or an allreduce would depend on the function.
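
To make that concrete, here is a hedged NumPy sketch (not TVM code, names are purely illustrative) of how a single function splits into per-GPU pieces plus a merge step, and how the choice of split axis determines whether the merge looks like an allgather or an allreduce:

```python
import numpy as np

def matmul(a, b):
    # Stand-in for the original, unsplit function.
    return a @ b

num_gpus = 4
rng = np.random.default_rng(0)
a = rng.random((8, 16))
b = rng.random((16, 8))

# Split along a non-reduction axis of `a`: each "GPU" computes a disjoint
# slice of the output, and the merge is a concatenation (allgather-like).
gathered = np.concatenate(
    [matmul(a_i, b) for a_i in np.split(a, num_gpus, axis=0)], axis=0
)
assert np.allclose(gathered, matmul(a, b))

# Split along the reduction axis: each "GPU" computes a partial result for
# the full output, and the merge is a summation (allreduce-like).
partials = [
    matmul(a_i, b_i)
    for a_i, b_i in zip(np.split(a, num_gpus, axis=1), np.split(b, num_gpus, axis=0))
]
assert np.allclose(sum(partials), matmul(a, b))
```

In the first split the per-GPU PrimFuncs are identical (same shapes, different input slices); in the second they are identical too, but the merge op differs, which is the distinction the sharding propagation would need to track.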
> How do I construct ShardedLlamaAttentionLayerTIR by not writing all the TIR primfuncs?

The current way to do so is to replace `cls.FUNC_NAME` with `LlamaAttentionLayer.get_global_var("FUNC_NAME")` within the TVMScript, then copy the PrimFunc over afterwards (example below). Most of the time the extra steps aren't worth the effort, but here I think it is worthwhile to show that every `PrimFunc` is expected to be identical.
```python
def test_function():
    @tvm.script.ir_module
    class DefinedInTVMScript:
        @R.function
        def main(
            x: R.Tensor((2, 128, 28), dtype="float32"),
            w: R.Tensor((64, 16, 3), dtype="float32"),
        ) -> R.Tensor((2, 64, 13), dtype="float32"):
            gv = R.call_tir(
                cls.conv1d,
                (x, w),
                out_sinfo=R.Tensor((2, 64, 13), dtype="float32"),
            )
            return gv

        @T.prim_func(private=True)
        def conv1d(
            A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"),
            B: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"),
            group_conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32"),
        ):
            pass

    @tvm.script.ir_module
    class AddedAfterward:
        @R.function
        def main(
            x: R.Tensor((2, 128, 28), dtype="float32"),
            w: R.Tensor((64, 16, 3), dtype="float32"),
        ) -> R.Tensor((2, 64, 13), dtype="float32"):
            gv = R.call_tir(
                DefinedInTVMScript.get_global_var("conv1d"),
                (x, w),
                out_sinfo=R.Tensor((2, 64, 13), dtype="float32"),
            )
            return gv

    # Copy every function other than "main" (here, just conv1d) into the
    # second module, under the same GlobalVar.
    for gvar, func in DefinedInTVMScript.functions.items():
        if gvar.name_hint != "main":
            AddedAfterward[gvar] = func

    tvm.ir.assert_structural_equal(DefinedInTVMScript, AddedAfterward)
```
(There's a slightly easier way to copy over top-level functions, after
https://github.com/apache/tvm/pull/15703, but it doesn't work for the
`cls.name_of_prim_func` notation.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]