Lunderberg commented on code in PR #16094:
URL: https://github.com/apache/tvm/pull/16094#discussion_r1391733985
##########
tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py:
##########
@@ -1060,16 +1738,19 @@ def test_mlp_pipeline_parallelism():
def test_decoder_layer():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayer
-    after = relax.distributed.transform.PropagateSharding()(mod)
+    after = relax.distributed.transform.PropagateSharding()(LlamaAttentionLayer)
     assert_structural_equal(after, ShardedLlamaAttentionLayer)
-def test_decoder_layer_dynamic_shape():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayerDynamicShape
+def test_decoder_layer_tir():
+    mod = relax.transform.LegalizeOps()(LlamaAttentionLayer)
Review Comment:
If a test case fails, it is important that a developer can tell what breakage
in functionality occurred, and where it occurred. Considering
`test_decoder_layer_tir`, I like the brevity of applying `PropagateSharding`
followed immediately by `assert_structural_equal`, as that immediately
identifies where a breakage occurred. However, with the 550-line input (lines
812 - 1371) and the 600-line expected output (lines 1373 - 1983), a reader
cannot easily tell where the differences are between the two.
I understand the desire to test a wide variety of inputs, and agree with
that goal. We also need to consider how to best write maintainable unit tests.
A failing unit test is often the first time a developer will see a particular
test file, and that developer is the audience who must be able to read,
understand, and (if necessary) update the unit test.
> The shapes are not changed because we are changing Tensor to DTensor
which contains an extra field placement.
Regarding the shapes, my apologies as I wasn't referring to the shape in the
`R.Tensor` or `R.DTensor`, but the shape in the `T.Buffer` of the PrimFunc.
For example, consider `lv8` in `test_decoder_layer_foo`. Before the
transform, the RHS has struct info `R.Tensor((4096, 4096), dtype="float16")`
and is passed into `cls.matmul` which expects shape `[4096, 4096]`.
After the transform, the RHS has struct info `R.DTensor((4096, 4096),
"float16", "mesh[0]", "S[1]")`. With `"mesh[0]"`, we know that there are two
workers, so the size of the local tensor on each GPU is `[4096, 4096 //
num_workers]`, or `[4096, 2048]`. However, the `cls.matmul` PrimFunc still
expects the RHS to have shape `[4096, 4096]`.
That's why I was asking how the TIR shapes are being handled, as this seems
like a clear discrepancy between the tensor shape provided by the relax
callsite and the buffer shape accepted by the PrimFunc callee.
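To make the shape mismatch concrete, here is a minimal sketch of the local-shape arithmetic described above. The `local_shape` helper and its `"S[d]"`/`"R"` spec parsing are illustrative only, not TVM's actual placement API; the point is that sharding dimension 1 across two workers halves that dimension, while the PrimFunc's `T.Buffer` would still declare the global shape.

```python
# Sketch (hypothetical helper, not TVM API): per-worker shape of a sharded
# tensor, mirroring R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]").
def local_shape(global_shape, placement, num_workers):
    """Return the local tensor shape on each worker.

    placement: "R" (replicated) or "S[d]" (sharded along dimension d).
    """
    shape = list(global_shape)
    if placement.startswith("S[") and placement.endswith("]"):
        dim = int(placement[2:-1])
        # Assumes the sharded dimension divides evenly among workers.
        assert shape[dim] % num_workers == 0
        shape[dim] //= num_workers
    return shape

# With two workers on mesh[0] and "S[1]":
# local_shape((4096, 4096), "S[1]", 2) -> [4096, 2048]
# but the un-rewritten cls.matmul PrimFunc still expects [4096, 4096].
```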
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]