ConvolutedDog commented on code in PR #18120:
URL: https://github.com/apache/tvm/pull/18120#discussion_r2196306711
##########
tests/python/relax/test_dataflow_rewriter.py:
##########
@@ -1169,9 +1171,9 @@ def expected(
],
)
- v = embedded_qkv_tuple[2]
q_embed = embedded_qkv_tuple[0]
k_embed = embedded_qkv_tuple[1]
+ v = embedded_qkv_tuple[2]
Review Comment:
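Thanks for your explanation, but in my version, the following three IRModules
are printed for `before`, `after`, and `expected`, respectively. Note that in
both `after` and `expected`, `v` is bound before `q_embed` and `k_embed`.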
```py
# from tvm.script import relax as R
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv:
R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) ->
R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state,
out_dtype="void")
qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,),
dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.split(qkv,
indices_or_sections=3, axis=0)
q: R.Tensor((4096,), dtype="float32") = qkv_tuple[0]
k: R.Tensor((4096,), dtype="float32") = qkv_tuple[1]
v: R.Tensor((4096,), dtype="float32") = qkv_tuple[2]
q_embed: R.Tensor((4096,), dtype="float32") =
R.call_pure_packed("rotary_embedding", (q,), sinfo_args=(R.Tensor((4096,),
dtype="float32"),))
k_embed: R.Tensor((4096,), dtype="float32") =
R.call_pure_packed("rotary_embedding", (k,), sinfo_args=(R.Tensor((4096,),
dtype="float32"),))
attention: R.Tensor((4096,)) =
R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache),
sinfo_args=(R.Tensor((4096,)),))
return attention
# from tvm.script import relax as R
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv:
R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) ->
R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state,
out_dtype="void")
gv: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,),
dtype="float32"), R.Tensor((4096,), dtype="float32")) =
R.call_pure_packed("split_rotary_embedding", (qkv,),
sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,),
dtype="float32"), R.Tensor((4096,), dtype="float32")))
v: R.Tensor((4096,), dtype="float32") = gv[2]
q_embed: R.Tensor((4096,), dtype="float32") = gv[0]
k_embed: R.Tensor((4096,), dtype="float32") = gv[1]
attention: R.Tensor((4096,)) =
R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache),
sinfo_args=(R.Tensor((4096,)),))
return attention
# from tvm.script import relax as R
@R.function(private=True)
def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv:
R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) ->
R.Tensor((4096,)):
qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state,
out_dtype="void")
embedded_qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"),
R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) =
R.call_pure_packed("split_rotary_embedding", (qkv,),
sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,),
dtype="float32"), R.Tensor((4096,), dtype="float32")))
v: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[2]
q_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[0]
k_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[1]
attention: R.Tensor((4096,)) =
R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache),
sinfo_args=(R.Tensor((4096,)),))
return attention
```
Result of running `pytest`:
```sh
(tvm-build-venv) [user1@localhost tvm-git]$ python3 -m pytest -s -v tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements
enabled targets: llvm; cuda; nvptx
pytest marker:
============================= test session starts ==============================
platform linux -- Python 3.11.13, pytest-8.4.0, pluggy-1.6.0 -- /home/user1/miniconda3/envs/tvm-build-venv/bin/python3
cachedir: .pytest_cache
rootdir: /home/user1/Github/tvm-git
configfile: pyproject.toml
collected 1 item
tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements PASSED
============================== 1 passed in 0.14s ===============================
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]