ConvolutedDog commented on code in PR #18120:
URL: https://github.com/apache/tvm/pull/18120#discussion_r2194552647


##########
tests/python/relax/test_dataflow_rewriter.py:
##########
@@ -1169,9 +1171,9 @@ def expected(
             ],
         )
 
-        v = embedded_qkv_tuple[2]
         q_embed = embedded_qkv_tuple[0]
         k_embed = embedded_qkv_tuple[1]
+        v = embedded_qkv_tuple[2]

Review Comment:
   May I ask which version of TVM you are using? With 0.21.dev0 on my
machine, the rewritten code is:
   ```py
   @R.function(private=True)
   def main(
       state: R.Tensor((4096,), dtype="float32"),
       proj_qkv: R.Tensor((12288, 4096), dtype="float32"),
       kv_cache: R.Object,
   ) -> R.Tensor((4096,)):
       qkv: R.Tensor((12288,), dtype="float32") = R.matmul(proj_qkv, state, out_dtype="void")
       gv: R.Tuple(
           R.Tensor((4096,), dtype="float32"),
           R.Tensor((4096,), dtype="float32"),
           R.Tensor((4096,), dtype="float32"),
       ) = R.call_pure_packed(
           "split_rotary_embedding",
           (qkv,),
           sinfo_args=(
               R.Tensor((4096,), dtype="float32"),
               R.Tensor((4096,), dtype="float32"),
               R.Tensor((4096,), dtype="float32"),
           ),
       )
       v: R.Tensor((4096,), dtype="float32") = gv[2]
       q_embed: R.Tensor((4096,), dtype="float32") = gv[0]
       k_embed: R.Tensor((4096,), dtype="float32") = gv[1]
       attention: R.Tensor((4096,)) = R.call_pure_packed(
           "compute_self_attention",
           (q_embed, k_embed, v, kv_cache),
           sinfo_args=(R.Tensor((4096,)),),
       )
       return attention
   ```
   So it is structurally equal to `expected`. I have also debugged the
backend C++ code and could not reproduce the problem you mentioned.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.