Lunderberg commented on code in PR #16094:
URL: https://github.com/apache/tvm/pull/16094#discussion_r1388099077


##########
tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py:
##########
@@ -68,13 +68,119 @@ def foo(
             x, weight1, out_dtype="void"
         )
        lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
-        lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1
         lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.matmul(
-            lv2, weight2, out_dtype="void"
+            lv1, weight2, out_dtype="void"
         )
         return lv3
 
 
 @I.ir_module

Review Comment:
   Nitpick: Parsing of TVMScript should be done within the test body (if used 
for a single test case), or within a helper function called by the body (if 
used by multiple test cases).  If the TVMScript is at module scope, parsing 
occurs when the module is imported, so a parse error halts the entire pytest 
suite at collection time.  Keeping TVMScript parsing inside the test body 
confines the failure to the execution of that one unit test.
   
   This is a significant usability issue when debugging the TVMScript parser, 
as it becomes impossible to tell how many test failures exist, let alone 
whether they share a common cause.



##########
tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py:
##########
@@ -1060,16 +1738,19 @@ def test_mlp_pipeline_parallelism():
 
 
 def test_decoder_layer():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayer
-    after = relax.distributed.transform.PropagateSharding()(mod)
+    after = relax.distributed.transform.PropagateSharding()(LlamaAttentionLayer)
     assert_structural_equal(after, ShardedLlamaAttentionLayer)
 
 
-def test_decoder_layer_dynamic_shape():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayerDynamicShape
+def test_decoder_layer_tir():
+    mod = relax.transform.LegalizeOps()(LlamaAttentionLayer)

Review Comment:
   With the length of the expected output in `ShardedLlamaAttentionLayerTIR`, I 
can't tell what behavior this test is intended to check.  Can we add a 
docstring describing how TIR calls should be handled when propagating the 
sharding?
   
   My initial guess is that they are treated as opaque, since the shapes used 
by `PrimFunc`s in `ShardedLlamaAttentionLayerTIR` are unsharded, but that 
isn't readily apparent to a reader.
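
   As a sketch of what such a docstring could look like (the stated 
semantics are only my guess from reading the expected output, and would 
need to be confirmed against the actual `PropagateSharding` behavior):

   ```python
   def test_decoder_layer_tir():
       """Propagate sharding through a module legalized into TIR.

       Assumed behavior (to be confirmed by the PR author): calls to TIR
       PrimFuncs are treated as opaque, so sharding annotations propagate
       at the Relax level while the PrimFunc bodies keep their original,
       unsharded shapes.
       """
   ```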



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
