masahi commented on PR #14798:
URL: https://github.com/apache/tvm/pull/14798#issuecomment-1537543264

   One odd thing: if I run the same attention workload via the CUTLASS example, the same kernel runs in about 1.3 ms (compared to our BYOC result, 2.4 ms), see below.
   
   ```
   $ nsys nvprof examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=8 --batch_size=2 --head_size=40 --head_size_v=40 --seq_length=4096 --seq_length_kv=4096
                                                                                
                                                             
   CUTLASS Attention:
   ====================================================
       {seq length Q, seq length KV, head size, head size V, head number, batch size} = {4096, 4096, 40, 40, 8, 2}.

       Runtime: 1.36964 ms
       GFLOPs: 19897.7

   Passed

    Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)     GridXYZ       BlockXYZ    Name
    --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ------------  ------------  ------
         ...
        13.8       30,160,277         22  1,370,921.7  1,368,225.0  1,346,113  1,414,209     16,718.9  64   8   2    32   4   1    void attention_kernel_batched_impl<AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, (bool)1, (…
   ```
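   As a quick sanity check on the figures above (no new measurements, just arithmetic on the quoted nsys row), the per-launch average implied by the totals matches both the printed `Avg (ns)` column and the example's self-reported runtime:

   ```python
   # Double-check the nsys figures quoted above (pure arithmetic, no profiling).
   total_ns = 30_160_277   # "Total Time (ns)" from the nsys row
   instances = 22          # "Instances"

   avg_ns = total_ns / instances
   print(f"avg per launch: {avg_ns:,.1f} ns")   # matches the 1,370,921.7 nsys reports

   # The per-launch average (~1.371 ms) agrees with the example's
   # self-reported runtime of 1.36964 ms to within ~0.1%.
   avg_ms = avg_ns / 1e6
   print(f"avg per launch: {avg_ms:.4f} ms")
   ```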
   
   I've also checked Triton / Flash Attention perf on the same workload by running https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py#L330 and saw around 1.2 - 1.3 ms. So I'm inclined to believe that ~1.3 ms is the right result for an attention kernel on this workload.
   
   So maybe there is something off in how we use this kernel from our BYOC? I compared [the generated code](https://gist.github.com/masahi/767e84ee17e9648621ff48251fa5bba5) against the CUTLASS example code but didn't find any difference. There is also no difference in the nvcc options that might affect performance, other than this `NDEBUG` stuff (that's how I found out about it). Any thoughts? @vinx13 @cyx-6

