Lunderberg commented on code in PR #16094:
URL: https://github.com/apache/tvm/pull/16094#discussion_r1393282812


##########
tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py:
##########
@@ -1060,16 +1738,19 @@ def test_mlp_pipeline_parallelism():
 
 
 def test_decoder_layer():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayer
-    after = relax.distributed.transform.PropagateSharding()(mod)
+    after = relax.distributed.transform.PropagateSharding()(LlamaAttentionLayer)
     assert_structural_equal(after, ShardedLlamaAttentionLayer)
 
 
-def test_decoder_layer_dynamic_shape():
-    # mod = relax.transform.LegalizeOps({"relax.reshape": lambda bb, call: bb.normalize(call)})(LlamaAttentionLayer)
-    mod = LlamaAttentionLayerDynamicShape
+def test_decoder_layer_tir():
+    mod = relax.transform.LegalizeOps()(LlamaAttentionLayer)

Review Comment:
   > It's an interesting topic to allow the split axis to be dynamic, but currently we just enforce that the axis has to be static for simplicity.
   
   I was thinking that it would be the other way around.  In order to correctly split a `PrimFunc`, the function must contain enough information to produce a smaller version of itself.  That would either require the split axis to be dynamic, so that size changes are handled automatically, or require undoing any constant-folding that has already been done using the split size.
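   To illustrate with a plain-Python analogy (not TVM API; the names `static_kernel` and `dynamic_kernel` are hypothetical): a kernel whose extent has been constant-folded into its body cannot produce a smaller copy of itself, while a kernel parameterized over the extent can be applied to a shard directly.
   
   ```python
   # Hedged sketch: analogy for why a split axis may need to stay dynamic.
   
   def static_kernel(a):
       # The extent 128 has been "constant-folded" into the body, so this
       # function only works on full-size inputs.
       return [a[i] * 2 for i in range(128)]
   
   def dynamic_kernel(a, n):
       # The extent is a parameter, so sharding just passes a smaller n.
       return [a[i] * 2 for i in range(n)]
   
   full = list(range(128))
   shard = full[:64]
   
   # static_kernel(shard) would raise IndexError; the dynamic version is fine:
   assert dynamic_kernel(shard, 64) == [x * 2 for x in shard]
   ```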
   
   > How do I construct ShardedLlamaAttentionLayerTIR by not writing all the TIR primfuncs?
   
   The current way to do so is by replacing `cls.FUNC_NAME` with `LlamaAttentionLayer.get_global_var("FUNC_NAME")` within the TVMScript, then copying the `PrimFunc` over afterwards (example below).  Most of the time, the extra steps make it not worth the effort, but here I think it is worthwhile to show that every `PrimFunc` is expected to be identical.
   
   ```python
   def test_function():
       @tvm.script.ir_module
       class DefinedInTVMScript:
           @R.function
           def main(
               x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dtype="float32")
           ) -> R.Tensor((2, 64, 13), dtype="float32"):
               gv = R.call_tir(cls.conv1d, (x, w), out_sinfo=R.Tensor((2, 64, 13), dtype="float32"))
               return gv
   
           @T.prim_func(private=True)
           def conv1d(
               A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"),
               B: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"),
               group_conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32"),
           ):
               pass
   
       @tvm.script.ir_module
       class AddedAfterward:
           @R.function
           def main(
               x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dtype="float32")
           ) -> R.Tensor((2, 64, 13), dtype="float32"):
               gv = R.call_tir(
                   DefinedInTVMScript.get_global_var("conv1d"),
                   (x, w),
                   out_sinfo=R.Tensor((2, 64, 13), dtype="float32"),
               )
               return gv
   
       # Copy every non-main function from the TVMScript-defined module into
       # the module that references it by GlobalVar.
       for gvar, func in DefinedInTVMScript.functions.items():
           if gvar.name_hint != "main":
               AddedAfterward[gvar] = func
   
       tvm.ir.assert_structural_equal(DefinedInTVMScript, AddedAfterward)
   ```
   
   (There's a slightly easier way to copy over top-level functions, after https://github.com/apache/tvm/pull/15703, but it doesn't work for the `cls.name_of_prim_func` notation.)


