ConvolutedDog commented on code in PR #18120:
URL: https://github.com/apache/tvm/pull/18120#discussion_r2196306711


##########
tests/python/relax/test_dataflow_rewriter.py:
##########
@@ -1169,9 +1171,9 @@ def expected(
             ],
         )
 
-        v = embedded_qkv_tuple[2]
         q_embed = embedded_qkv_tuple[0]
         k_embed = embedded_qkv_tuple[1]
+        v = embedded_qkv_tuple[2]

Review Comment:
   Thanks for your explanation. In my version, however, the following three 
IRModules are printed for `before`, `after`, and `expected`, in that order:
   
   ```py
   # from tvm.script import relax as R
   
   @R.function(private=True)
   def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: 
R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> 
R.Tensor((4096,)):
       qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, 
out_dtype="void")
       qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), 
dtype="float32"), R.Tensor((4096,), dtype="float32")) = R.split(qkv, 
indices_or_sections=3, axis=0)
       q: R.Tensor((4096,), dtype="float32") = qkv_tuple[0]
       k: R.Tensor((4096,), dtype="float32") = qkv_tuple[1]
       v: R.Tensor((4096,), dtype="float32") = qkv_tuple[2]
       q_embed: R.Tensor((4096,), dtype="float32") = 
R.call_pure_packed("rotary_embedding", (q,), sinfo_args=(R.Tensor((4096,), 
dtype="float32"),))
       k_embed: R.Tensor((4096,), dtype="float32") = 
R.call_pure_packed("rotary_embedding", (k,), sinfo_args=(R.Tensor((4096,), 
dtype="float32"),))
       attention: R.Tensor((4096,)) = 
R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), 
sinfo_args=(R.Tensor((4096,)),))
       return attention
   
   # from tvm.script import relax as R
   
   @R.function(private=True)
   def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: 
R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> 
R.Tensor((4096,)):
       qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, 
out_dtype="void")
       gv: R.Tuple(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), 
dtype="float32"), R.Tensor((4096,), dtype="float32")) = 
R.call_pure_packed("split_rotary_embedding", (qkv,), 
sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), 
dtype="float32"), R.Tensor((4096,), dtype="float32")))
       v: R.Tensor((4096,), dtype="float32") = gv[2]
       q_embed: R.Tensor((4096,), dtype="float32") = gv[0]
       k_embed: R.Tensor((4096,), dtype="float32") = gv[1]
       attention: R.Tensor((4096,)) = 
R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), 
sinfo_args=(R.Tensor((4096,)),))
       return attention
   
   # from tvm.script import relax as R
   
   @R.function(private=True)
   def main(state: R.Tensor((4096,), dtype="float32"), proj_qkv: 
R.Tensor((12288, 4096), dtype="float32"), kv_cache: R.Object) -> 
R.Tensor((4096,)):
       qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, 
out_dtype="void")
       embedded_qkv_tuple: R.Tuple(R.Tensor((4096,), dtype="float32"), 
R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), dtype="float32")) = 
R.call_pure_packed("split_rotary_embedding", (qkv,), 
sinfo_args=(R.Tensor((4096,), dtype="float32"), R.Tensor((4096,), 
dtype="float32"), R.Tensor((4096,), dtype="float32")))
       v: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[2]
       q_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[0]
       k_embed: R.Tensor((4096,), dtype="float32") = embedded_qkv_tuple[1]
       attention: R.Tensor((4096,)) = 
R.call_pure_packed("compute_self_attention", (q_embed, k_embed, v, kv_cache), 
sinfo_args=(R.Tensor((4096,)),))
       return attention
   ```
   
   Result of running `pytest`:
   
   ```sh
   (tvm-build-venv) [user1@localhost tvm-git]$ python3 -m pytest -s -v 
tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements
   enabled targets: llvm; cuda; nvptx
   pytest marker: 
   
   ============================= test session starts ==============================
   platform linux -- Python 3.11.13, pytest-8.4.0, pluggy-1.6.0 -- 
/home/user1/miniconda3/envs/tvm-build-venv/bin/python3
   cachedir: .pytest_cache
   rootdir: /home/user1/Github/tvm-git
   configfile: pyproject.toml
   collected 1 item                                                             
                                                                                
                                                                   
   
   
tests/python/relax/test_dataflow_rewriter.py::test_rewrite_of_implicit_tuple_with_three_elements
 PASSED
   
   
   ============================== 1 passed in 0.14s ===============================
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to