Phoslight commented on code in PR #18120:
URL: https://github.com/apache/tvm/pull/18120#discussion_r2193882619
##########
tests/python/relax/test_dataflow_rewriter.py:
##########
@@ -1169,9 +1171,9 @@ def expected(
],
)
- v = embedded_qkv_tuple[2]
q_embed = embedded_qkv_tuple[0]
k_embed = embedded_qkv_tuple[1]
+ v = embedded_qkv_tuple[2]
Review Comment:
Thank you for the review. However, `before()` was rewritten as
```
@R.function(private=True)
def main(
    state: R.Tensor((4096,), dtype="float32"),
    proj_qkv: R.Tensor((12288, 4096), dtype="float32"),
    kv_cache: R.Object,
) -> R.Tensor((4096,)):
    qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
    gv: R.Tuple(
        R.Tensor((4096,), dtype="float32"),
        R.Tensor((4096,), dtype="float32"),
        R.Tensor((4096,), dtype="float32"),
    ) = R.call_pure_packed(
        "split_rotary_embedding",
        (qkv,),
        sinfo_args=(
            R.Tensor((4096,), dtype="float32"),
            R.Tensor((4096,), dtype="float32"),
            R.Tensor((4096,), dtype="float32"),
        ),
    )
    gv8: R.Tensor((4096,), dtype="float32") = gv[0]  # <<<<<======== access order here: [0] -> [1] -> [2]
    gv9: R.Tensor((4096,), dtype="float32") = gv[1]
    gv10: R.Tensor((4096,), dtype="float32") = gv[2]
    attention: R.Tensor((4096,)) = R.call_pure_packed(
        "compute_self_attention",
        (gv8, gv9, gv10, kv_cache),
        sinfo_args=(R.Tensor((4096,)),),
    )
    return attention
```
while the original version of `expected()` was rewritten as:
```
@R.function(private=True)
def main(
    state: R.Tensor((4096,), dtype="float32"),
    proj_qkv: R.Tensor((12288, 4096), dtype="float32"),
    kv_cache: R.Object,
) -> R.Tensor((4096,)):
    qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
    embedded_qkv_tuple: R.Tuple(
        R.Tensor((4096,), dtype="float32"),
        R.Tensor((4096,), dtype="float32"),
        R.Tensor((4096,), dtype="float32"),
    ) = R.call_pure_packed(
        "split_rotary_embedding",
        (qkv,),
        sinfo_args=(
            R.Tensor((4096,), dtype="float32"),
            R.Tensor((4096,), dtype="float32"),
            R.Tensor((4096,), dtype="float32"),
        ),
    )
    v: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[2]  # <<<<<======== the original order: [2] -> [0] -> [1]
    q_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[0]
    k_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[1]
    attention: R.Tensor((4096,)) = R.call_pure_packed(
        "compute_self_attention",
        (q_embed, k_embed, v, kv_cache),
        sinfo_args=(R.Tensor((4096,)),),
    )
    return attention
```
I just ran the test again: if we don't switch the order of these lines, the
test fails on the `tvm.ir.assert_structural_equal(expected, after)`
assertion.
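For what it's worth, here is a rough sketch in plain Python (not TVM's actual checker, and the names are illustrative) of why the binding order matters: structural equality walks the two functions' binding sequences in order, so the same set of bindings emitted in a different order does not compare equal.

```python
# Illustrative only: model each function body as an ordered list of
# (variable name, tuple index accessed) bindings.
before = [("gv8", 0), ("gv9", 1), ("gv10", 2)]             # order emitted by the rewriter
expected_old = [("v", 2), ("q_embed", 0), ("k_embed", 1)]  # original order in expected()
expected_new = [("q_embed", 0), ("k_embed", 1), ("v", 2)]  # reordered to match

def access_order(bindings):
    """Return just the tuple indices accessed, in binding order."""
    return [index for _, index in bindings]

# An order-sensitive comparison, as a stand-in for structural equality:
print(access_order(before) == access_order(expected_new))  # True
print(access_order(before) == access_order(expected_old))  # False
```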
Please let me know if this explanation sounds reasonable to you, or if
further clarification is needed.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]