Lunderberg commented on code in PR #16094:
URL: https://github.com/apache/tvm/pull/16094#discussion_r1388099077
##########
tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py:
##########
@@ -68,13 +68,119 @@ def foo(
                x, weight1, out_dtype="void"
            )
            lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
-           lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1
            lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.matmul(
-               lv2, weight2, out_dtype="void"
+               lv1, weight2, out_dtype="void"
            )
            return lv3
+@I.ir_module
Review Comment:
Nitpick: Parsing of TVMScript should be done within the test body (if used
by a single test case), or within a helper function called by the body (if
used by multiple test cases). If the TVMScript is at module scope, parsing
occurs when the module is imported, which halts the entire pytest suite at
collection time. Keeping TVMScript parsing inside the test body instead
causes the failure to occur within the execution of that single unit test.
This is a large usability issue when debugging the TVMScript parser, as it
becomes impossible to tell how many test failures exist, let alone whether
there are any commonalities between them.
##########
tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py:
##########
@@ -1060,16 +1738,19 @@ def test_mlp_pipeline_parallelism():
 def test_decoder_layer():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayer
-    after = relax.distributed.transform.PropagateSharding()(mod)
+    after = relax.distributed.transform.PropagateSharding()(LlamaAttentionLayer)
     assert_structural_equal(after, ShardedLlamaAttentionLayer)
-def test_decoder_layer_dynamic_shape():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayerDynamicShape
+def test_decoder_layer_tir():
+    mod = relax.transform.LegalizeOps()(LlamaAttentionLayer)
Review Comment:
Given the length of the expected output in `ShardedLlamaAttentionLayerTIR`,
I can't tell what behavior this test is intended to check. Can we add a
docstring describing how TIR calls should be handled when propagating the
sharding?
My initial guess would be that they are treated as opaque, because the
shapes used by `PrimFunc`s in `ShardedLlamaAttentionLayerTIR` are unsharded,
but that isn't readily apparent to a reader.
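As a sketch of the kind of docstring being requested (the wording below is
hypothetical and encodes the reviewer's guess about the semantics, which
the PR author would need to confirm; the test body is elided):

```python
def test_decoder_layer_tir():
    """Check sharding propagation across call_tir boundaries.

    Reviewer's guess, to be confirmed: after relax.transform.LegalizeOps
    lowers high-level ops to PrimFuncs, PropagateSharding treats the
    resulting call_tir calls as opaque. Sharding annotations flow through
    the relax-level calls, while the PrimFunc bodies in
    ShardedLlamaAttentionLayerTIR keep their original, unsharded shapes.
    """
    # Body elided; see the diff above for the actual implementation.
```

A short docstring like this would let a reader check the intent against the
long expected-output module without reverse-engineering it.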
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]