adfwer233 opened a new pull request, #15389:
URL: https://github.com/apache/tvm/pull/15389

   This PR introduces a new dlight rule that tensorizes matrix multiplication.
   
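   With this PR, the end-to-end performance of LLaMA improves remarkably on my RTX 2080 Ti target. (Note: in the profiles below, the total encoding time is 707.8 ms in the "before" run and 2184.9 ms in the "with this PR" run, so the two section labels may be swapped — please double-check.)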
   
   - Profiling results before this PR (seq_len = 1000)
   
   ```
   Time elapsed: encoding 0.6713087558746338 seconds, decoding 0.028337717056274414 secs
   Profiling...
   ======================= Encoding Profiling =======================
   Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
   fused_NT_matmul3_silu1                            4.4564      32      
142.6060          20.15     114.84          25.1649           (1, 1001, 4096), 
(11008, 4096), (1, 1001, 11008)
   NT_matmul1                                        1.1672      96      
112.0528          15.83     47.64           39.8590           (1, 1001, 4096), 
(4096, 4096), (1, 1001, 4096)
   fused_NT_matmul3_multiply1                        3.3154      32      
106.0936          14.99     135.85          40.0161           (1, 1001, 4096), 
(11008, 4096), (1, 1001, 11008), (1, 1001, 11008)
   fused_NT_matmul4_add1                             3.2164      32      
102.9257          14.54     122.66          37.2410           (1, 1001, 11008), 
(4096, 11008), (1, 1001, 4096), (1, 1001, 4096)
   fused_NT_matmul2_divide2_maximum1_minimum1_cast3  1.7297      32      
55.3503           7.82      139.87          78.9666           (1, 32, 1001, 
128), (1, 32, 1001, 128), (1, 1, 1001, 1001), (1, 32, 1001, 1001)
   fused_NT_matmul1_add1                             1.3677      32      
43.7674           6.18      55.46           39.5992           (1, 1001, 4096), 
(4096, 4096), (1, 1001, 4096), (1, 1001, 4096)
   fused_decode3                                     0.5820      64      
37.2469           5.26      105.46          176.9668          (824, 11008), 
(103, 11008), (11008, 4096)
   matmul8                                           1.0648      32      
34.0737           4.81      76.80           70.4336           (1, 32, 1001, 
1001), (1, 32, 1001, 128), (1, 32, 1001, 128)
   fused_decode2                                     0.2292      128     
29.3383           4.15      39.24           167.1969          (824, 4096), 
(103, 4096), (4096, 4096)
   fused_decode4                                     0.6302      32      
20.1676           2.85      105.41          163.3289          (2208, 4096), 
(276, 4096), (4096, 11008)
   fused_softmax2_cast4                              0.4280      32      
13.6972           1.94      183.47          418.5889          (1, 32, 1001, 
1001), (1, 32, 1001, 1001)
   transpose2                                        0.0594      96      5.7022 
           0.81      15.64           257.1472          (1, 1001, 32, 128), (1, 
32, 1001, 128)
   rms_norm                                          0.0446      65      2.8969 
           0.41      15.65           342.8895          (1, 1001, 4096), 
(4096,), (1, 1001, 4096)
   transpose5                                        0.0433      32      1.3840 
           0.20      15.64           353.1500          (1, 32, 1001, 128), (1, 
1001, 32, 128)
   fused_fused_decode9_fused_matmul6_cast2           0.4230      1       0.4230 
           0.06      56.71           130.9365          (824, 32000), (103, 
32000), (1, 1, 4096), (1, 1, 32000)
   fused_fused_decode1_take1                         0.0315      1       0.0315 
           0.00      64.40           1994.1480         (32000, 824), (32000, 
103), (1001,), (1001, 4096)
   extend_te                                         0.0167      1       0.0167 
           0.00      3.82            223.9991          (1, 1, 1001, 1001), (1, 
1, 1001, 1001)
   slice                                             0.0026      1       0.0026 
           0.00      7.83            2990.6865         (1, 1001, 4096), (1, 1, 
4096)
   Total time: 707.7763 ms
   
   ======================= Decoding Profiling =======================
   Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
   fused_fused_decode6_matmul2                       0.0460      96      4.4175 
           18.00     7.26            154.0292          (824, 4096), (103, 
4096), (1, 1, 4096), (1, 1, 4096)
   fused_fused_decode7_fused_matmul4_multiply        0.1229      32      3.9333 
           16.03     19.51           155.0312          (824, 11008), (103, 
11008), (1, 1, 4096), (1, 1, 11008), (1, 1, 11008)
   fused_fused_decode8_fused_matmul5_add             0.1221      32      3.9076 
           15.92     19.44           155.4895          (2208, 4096), (276, 
4096), (1, 1, 11008), (1, 1, 4096), (1, 1, 4096)
   transpose2                                        0.0597      64      3.8228 
           15.58     15.66           255.9667          (1, 1002, 32, 128), (1, 
32, 1002, 128)
   fused_fused_decode7_fused_matmul4_silu            0.1177      32      3.7667 
           15.35     19.49           161.7167          (824, 11008), (103, 
11008), (1, 1, 4096), (1, 1, 11008)
   fused_NT_matmul_divide_maximum_minimum_cast       0.0554      32      1.7729 
           7.22      7.96            140.3085          (1, 32, 1, 128), (1, 32, 
1002, 128), (1, 1, 1, 1002), (1, 32, 1, 1002)
   fused_fused_decode6_fused_matmul2_add             0.0470      32      1.5045 
           6.13      7.27            150.9115          (824, 4096), (103, 
4096), (1, 1, 4096), (1, 1, 4096), (1, 1, 4096)
   matmul3                                           0.0180      32      0.5765 
           2.35      7.90            428.0822          (1, 32, 1, 1002), (1, 
32, 1002, 128), (1, 32, 1, 128)
   rms_norm1                                         0.0059      65      0.3846 
           1.57      0.02            3.8683            (1, 1, 4096), (4096,), 
(1, 1, 4096)
   fused_fused_decode9_fused_matmul6_cast2           0.3578      1       0.3578 
           1.46      56.71           154.7936          (824, 32000), (103, 
32000), (1, 1, 4096), (1, 1, 32000)
   fused_softmax_cast1                               0.0029      32      0.0914 
           0.37      0.18            62.7517           (1, 32, 1, 1002), (1, 
32, 1, 1002)
   full                                              0.0028      1       0.0028 
           0.01      0.00            0.6785            (1, 1, 1, 1002)
   fused_fused_decode1_take                          0.0026      1       0.0026 
           0.01      56.59           21041.0947        (32000, 824), (32000, 
103), (1,), (1, 4096)
   Total time: 24.5409 ms
   ```
   
   - Profiling results with this PR (seq_len = 1000)
   
   ```
   Time elapsed: encoding 2.130176544189453 seconds, decoding 0.027759075164794922 secs
   Profiling...
   ======================= Encoding Profiling =======================
   Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
   NT_matmul1                                        5.0671      96      
486.4369          22.26     47.64           9.1817            (1, 1001, 4096), 
(4096, 4096), (1, 1001, 4096)
   fused_NT_matmul4_add1                             14.0175     32      
448.5585          20.53     122.66          8.5453            (1, 1001, 11008), 
(4096, 11008), (1, 1001, 4096), (1, 1001, 4096)
   fused_NT_matmul3_multiply1                        13.8555     32      
443.3747          20.29     135.85          9.5753            (1, 1001, 4096), 
(11008, 4096), (1, 1001, 11008), (1, 1001, 11008)
   fused_NT_matmul3_silu1                            13.6471     32      
436.7056          19.99     114.84          8.2176            (1, 1001, 4096), 
(11008, 4096), (1, 1001, 11008)
   fused_NT_matmul1_add1                             5.2856      32      
169.1377          7.74      55.46           10.2470           (1, 1001, 4096), 
(4096, 4096), (1, 1001, 4096), (1, 1001, 4096)
   fused_NT_matmul2_divide2_maximum1_minimum1_cast3  1.6962      32      
54.2783           2.48      139.87          80.5262           (1, 32, 1001, 
128), (1, 32, 1001, 128), (1, 1, 1001, 1001), (1, 32, 1001, 1001)
   fused_decode3                                     0.5787      64      
37.0366           1.70      105.46          177.9717          (824, 11008), 
(103, 11008), (11008, 4096)
   matmul8                                           1.1080      32      
35.4559           1.62      76.80           67.6879           (1, 32, 1001, 
1001), (1, 32, 1001, 128), (1, 32, 1001, 128)
   fused_decode2                                     0.2291      128     
29.3250           1.34      39.24           167.2727          (824, 4096), 
(103, 4096), (4096, 4096)
   fused_decode4                                     0.6377      32      
20.4066           0.93      105.41          161.4159          (2208, 4096), 
(276, 4096), (4096, 11008)
   fused_softmax2_cast4                              0.4283      32      
13.7044           0.63      183.47          418.3677          (1, 32, 1001, 
1001), (1, 32, 1001, 1001)
   transpose2                                        0.0598      96      5.7364 
           0.26      15.64           255.6140          (1, 1001, 32, 128), (1, 
32, 1001, 128)
   rms_norm                                          0.0446      65      2.8968 
           0.13      15.65           342.8976          (1, 1001, 4096), 
(4096,), (1, 1001, 4096)
   transpose5                                        0.0433      32      1.3843 
           0.06      15.64           353.0760          (1, 32, 1001, 128), (1, 
1001, 32, 128)
   fused_fused_decode9_fused_matmul6_cast2           0.4007      1       0.4007 
           0.02      56.71           138.2198          (824, 32000), (103, 
32000), (1, 1, 4096), (1, 1, 32000)
   fused_fused_decode1_take1                         0.0316      1       0.0316 
           0.00      64.40           1990.9568         (32000, 824), (32000, 
103), (1001,), (1001, 4096)
   extend_te                                         0.0167      1       0.0167 
           0.00      3.82            223.7772          (1, 1, 1001, 1001), (1, 
1, 1001, 1001)
   slice                                             0.0027      1       0.0027 
           0.00      7.83            2825.6112         (1, 1001, 4096), (1, 1, 
4096)
   Total time: 2184.8893 ms
   
   ======================= Decoding Profiling =======================
   Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
   fused_fused_decode6_matmul2                       0.0460      96      4.4193 
           18.00     7.26            153.9649          (824, 4096), (103, 
4096), (1, 1, 4096), (1, 1, 4096)
   fused_fused_decode7_fused_matmul4_multiply        0.1244      32      3.9821 
           16.22     19.51           153.1303          (824, 11008), (103, 
11008), (1, 1, 4096), (1, 1, 11008), (1, 1, 11008)
   fused_fused_decode8_fused_matmul5_add             0.1218      32      3.8965 
           15.87     19.44           155.9304          (2208, 4096), (276, 
4096), (1, 1, 11008), (1, 1, 4096), (1, 1, 4096)
   transpose2                                        0.0600      64      3.8421 
           15.65     15.66           254.6841          (1, 1002, 32, 128), (1, 
32, 1002, 128)
   fused_fused_decode7_fused_matmul4_silu            0.1168      32      3.7381 
           15.23     19.49           162.9531          (824, 11008), (103, 
11008), (1, 1, 4096), (1, 1, 11008)
   fused_NT_matmul_divide_maximum_minimum_cast       0.0554      32      1.7733 
           7.22      7.96            140.2760          (1, 32, 1, 128), (1, 32, 
1002, 128), (1, 1, 1, 1002), (1, 32, 1, 1002)
   fused_fused_decode6_fused_matmul2_add             0.0463      32      1.4826 
           6.04      7.27            153.1449          (824, 4096), (103, 
4096), (1, 1, 4096), (1, 1, 4096), (1, 1, 4096)
   matmul3                                           0.0179      32      0.5713 
           2.33      7.90            431.9594          (1, 32, 1, 1002), (1, 
32, 1002, 128), (1, 32, 1, 128)
   rms_norm1                                         0.0059      65      0.3827 
           1.56      0.02            3.8878            (1, 1, 4096), (4096,), 
(1, 1, 4096)
   fused_fused_decode9_fused_matmul6_cast2           0.3687      1       0.3687 
           1.50      56.71           150.2006          (824, 32000), (103, 
32000), (1, 1, 4096), (1, 1, 32000)
   fused_softmax_cast1                               0.0027      32      0.0859 
           0.35      0.18            66.7356           (1, 32, 1, 1002), (1, 
32, 1, 1002)
   full                                              0.0029      1       0.0029 
           0.01      0.00            0.6545            (1, 1, 1, 1002)
   fused_fused_decode1_take                          0.0026      1       0.0026 
           0.01      56.59           20890.1001        (32000, 824), (32000, 
103), (1,), (1, 4096)
   Total time: 24.5482 ms
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to