adfwer233 opened a new pull request, #15389:
URL: https://github.com/apache/tvm/pull/15389
This PR introduces a new dlight rule to tensorize matrix multiplication.
With this PR, the end-to-end performance of LLaMA improved remarkably on
my RTX 2080 Ti target.
- Profiling result before (seq_len = 1000)
```
Time elapsed: encoding 0.6713087558746338 seconds, decoding 0.028337717056274414 secs
Profiling...
======================= Encoding Profiling =======================
Name                                              Time (ms)  Count  Total time (ms)  Pct (%)  Memory (MB)  Bandwidth (GB/s)  Shape
fused_NT_matmul3_silu1                            4.4564     32     142.6060         20.15    114.84       25.1649           (1, 1001, 4096), (11008, 4096), (1, 1001, 11008)
NT_matmul1                                        1.1672     96     112.0528         15.83    47.64        39.8590           (1, 1001, 4096), (4096, 4096), (1, 1001, 4096)
fused_NT_matmul3_multiply1                        3.3154     32     106.0936         14.99    135.85       40.0161           (1, 1001, 4096), (11008, 4096), (1, 1001, 11008), (1, 1001, 11008)
fused_NT_matmul4_add1                             3.2164     32     102.9257         14.54    122.66       37.2410           (1, 1001, 11008), (4096, 11008), (1, 1001, 4096), (1, 1001, 4096)
fused_NT_matmul2_divide2_maximum1_minimum1_cast3  1.7297     32     55.3503          7.82     139.87       78.9666           (1, 32, 1001, 128), (1, 32, 1001, 128), (1, 1, 1001, 1001), (1, 32, 1001, 1001)
fused_NT_matmul1_add1                             1.3677     32     43.7674          6.18     55.46        39.5992           (1, 1001, 4096), (4096, 4096), (1, 1001, 4096), (1, 1001, 4096)
fused_decode3                                     0.5820     64     37.2469          5.26     105.46       176.9668          (824, 11008), (103, 11008), (11008, 4096)
matmul8                                           1.0648     32     34.0737          4.81     76.80        70.4336           (1, 32, 1001, 1001), (1, 32, 1001, 128), (1, 32, 1001, 128)
fused_decode2                                     0.2292     128    29.3383          4.15     39.24        167.1969          (824, 4096), (103, 4096), (4096, 4096)
fused_decode4                                     0.6302     32     20.1676          2.85     105.41       163.3289          (2208, 4096), (276, 4096), (4096, 11008)
fused_softmax2_cast4                              0.4280     32     13.6972          1.94     183.47       418.5889          (1, 32, 1001, 1001), (1, 32, 1001, 1001)
transpose2                                        0.0594     96     5.7022           0.81     15.64        257.1472          (1, 1001, 32, 128), (1, 32, 1001, 128)
rms_norm                                          0.0446     65     2.8969           0.41     15.65        342.8895          (1, 1001, 4096), (4096,), (1, 1001, 4096)
transpose5                                        0.0433     32     1.3840           0.20     15.64        353.1500          (1, 32, 1001, 128), (1, 1001, 32, 128)
fused_fused_decode9_fused_matmul6_cast2           0.4230     1      0.4230           0.06     56.71        130.9365          (824, 32000), (103, 32000), (1, 1, 4096), (1, 1, 32000)
fused_fused_decode1_take1                         0.0315     1      0.0315           0.00     64.40        1994.1480         (32000, 824), (32000, 103), (1001,), (1001, 4096)
extend_te                                         0.0167     1      0.0167           0.00     3.82         223.9991          (1, 1, 1001, 1001), (1, 1, 1001, 1001)
slice                                             0.0026     1      0.0026           0.00     7.83         2990.6865         (1, 1001, 4096), (1, 1, 4096)
Total time: 707.7763 ms
======================= Decoding Profiling =======================
Name                                              Time (ms)  Count  Total time (ms)  Pct (%)  Memory (MB)  Bandwidth (GB/s)  Shape
fused_fused_decode6_matmul2                       0.0460     96     4.4175           18.00    7.26         154.0292          (824, 4096), (103, 4096), (1, 1, 4096), (1, 1, 4096)
fused_fused_decode7_fused_matmul4_multiply        0.1229     32     3.9333           16.03    19.51        155.0312          (824, 11008), (103, 11008), (1, 1, 4096), (1, 1, 11008), (1, 1, 11008)
fused_fused_decode8_fused_matmul5_add             0.1221     32     3.9076           15.92    19.44        155.4895          (2208, 4096), (276, 4096), (1, 1, 11008), (1, 1, 4096), (1, 1, 4096)
transpose2                                        0.0597     64     3.8228           15.58    15.66        255.9667          (1, 1002, 32, 128), (1, 32, 1002, 128)
fused_fused_decode7_fused_matmul4_silu            0.1177     32     3.7667           15.35    19.49        161.7167          (824, 11008), (103, 11008), (1, 1, 4096), (1, 1, 11008)
fused_NT_matmul_divide_maximum_minimum_cast       0.0554     32     1.7729           7.22     7.96         140.3085          (1, 32, 1, 128), (1, 32, 1002, 128), (1, 1, 1, 1002), (1, 32, 1, 1002)
fused_fused_decode6_fused_matmul2_add             0.0470     32     1.5045           6.13     7.27         150.9115          (824, 4096), (103, 4096), (1, 1, 4096), (1, 1, 4096), (1, 1, 4096)
matmul3                                           0.0180     32     0.5765           2.35     7.90         428.0822          (1, 32, 1, 1002), (1, 32, 1002, 128), (1, 32, 1, 128)
rms_norm1                                         0.0059     65     0.3846           1.57     0.02         3.8683            (1, 1, 4096), (4096,), (1, 1, 4096)
fused_fused_decode9_fused_matmul6_cast2           0.3578     1      0.3578           1.46     56.71        154.7936          (824, 32000), (103, 32000), (1, 1, 4096), (1, 1, 32000)
fused_softmax_cast1                               0.0029     32     0.0914           0.37     0.18         62.7517           (1, 32, 1, 1002), (1, 32, 1, 1002)
full                                              0.0028     1      0.0028           0.01     0.00         0.6785            (1, 1, 1, 1002)
fused_fused_decode1_take                          0.0026     1      0.0026           0.01     56.59        21041.0947        (32000, 824), (32000, 103), (1,), (1, 4096)
Total time: 24.5409 ms
```
- Profiling result with this PR (seq_len = 1000)
```
Time elapsed: encoding 2.130176544189453 seconds, decoding 0.027759075164794922 secs
Profiling...
======================= Encoding Profiling =======================
Name                                              Time (ms)  Count  Total time (ms)  Pct (%)  Memory (MB)  Bandwidth (GB/s)  Shape
NT_matmul1                                        5.0671     96     486.4369         22.26    47.64        9.1817            (1, 1001, 4096), (4096, 4096), (1, 1001, 4096)
fused_NT_matmul4_add1                             14.0175    32     448.5585         20.53    122.66       8.5453            (1, 1001, 11008), (4096, 11008), (1, 1001, 4096), (1, 1001, 4096)
fused_NT_matmul3_multiply1                        13.8555    32     443.3747         20.29    135.85       9.5753            (1, 1001, 4096), (11008, 4096), (1, 1001, 11008), (1, 1001, 11008)
fused_NT_matmul3_silu1                            13.6471    32     436.7056         19.99    114.84       8.2176            (1, 1001, 4096), (11008, 4096), (1, 1001, 11008)
fused_NT_matmul1_add1                             5.2856     32     169.1377         7.74     55.46        10.2470           (1, 1001, 4096), (4096, 4096), (1, 1001, 4096), (1, 1001, 4096)
fused_NT_matmul2_divide2_maximum1_minimum1_cast3  1.6962     32     54.2783          2.48     139.87       80.5262           (1, 32, 1001, 128), (1, 32, 1001, 128), (1, 1, 1001, 1001), (1, 32, 1001, 1001)
fused_decode3                                     0.5787     64     37.0366          1.70     105.46       177.9717          (824, 11008), (103, 11008), (11008, 4096)
matmul8                                           1.1080     32     35.4559          1.62     76.80        67.6879           (1, 32, 1001, 1001), (1, 32, 1001, 128), (1, 32, 1001, 128)
fused_decode2                                     0.2291     128    29.3250          1.34     39.24        167.2727          (824, 4096), (103, 4096), (4096, 4096)
fused_decode4                                     0.6377     32     20.4066          0.93     105.41       161.4159          (2208, 4096), (276, 4096), (4096, 11008)
fused_softmax2_cast4                              0.4283     32     13.7044          0.63     183.47       418.3677          (1, 32, 1001, 1001), (1, 32, 1001, 1001)
transpose2                                        0.0598     96     5.7364           0.26     15.64        255.6140          (1, 1001, 32, 128), (1, 32, 1001, 128)
rms_norm                                          0.0446     65     2.8968           0.13     15.65        342.8976          (1, 1001, 4096), (4096,), (1, 1001, 4096)
transpose5                                        0.0433     32     1.3843           0.06     15.64        353.0760          (1, 32, 1001, 128), (1, 1001, 32, 128)
fused_fused_decode9_fused_matmul6_cast2           0.4007     1      0.4007           0.02     56.71        138.2198          (824, 32000), (103, 32000), (1, 1, 4096), (1, 1, 32000)
fused_fused_decode1_take1                         0.0316     1      0.0316           0.00     64.40        1990.9568         (32000, 824), (32000, 103), (1001,), (1001, 4096)
extend_te                                         0.0167     1      0.0167           0.00     3.82         223.7772          (1, 1, 1001, 1001), (1, 1, 1001, 1001)
slice                                             0.0027     1      0.0027           0.00     7.83         2825.6112         (1, 1001, 4096), (1, 1, 4096)
Total time: 2184.8893 ms
======================= Decoding Profiling =======================
Name                                              Time (ms)  Count  Total time (ms)  Pct (%)  Memory (MB)  Bandwidth (GB/s)  Shape
fused_fused_decode6_matmul2                       0.0460     96     4.4193           18.00    7.26         153.9649          (824, 4096), (103, 4096), (1, 1, 4096), (1, 1, 4096)
fused_fused_decode7_fused_matmul4_multiply        0.1244     32     3.9821           16.22    19.51        153.1303          (824, 11008), (103, 11008), (1, 1, 4096), (1, 1, 11008), (1, 1, 11008)
fused_fused_decode8_fused_matmul5_add             0.1218     32     3.8965           15.87    19.44        155.9304          (2208, 4096), (276, 4096), (1, 1, 11008), (1, 1, 4096), (1, 1, 4096)
transpose2                                        0.0600     64     3.8421           15.65    15.66        254.6841          (1, 1002, 32, 128), (1, 32, 1002, 128)
fused_fused_decode7_fused_matmul4_silu            0.1168     32     3.7381           15.23    19.49        162.9531          (824, 11008), (103, 11008), (1, 1, 4096), (1, 1, 11008)
fused_NT_matmul_divide_maximum_minimum_cast       0.0554     32     1.7733           7.22     7.96         140.2760          (1, 32, 1, 128), (1, 32, 1002, 128), (1, 1, 1, 1002), (1, 32, 1, 1002)
fused_fused_decode6_fused_matmul2_add             0.0463     32     1.4826           6.04     7.27         153.1449          (824, 4096), (103, 4096), (1, 1, 4096), (1, 1, 4096), (1, 1, 4096)
matmul3                                           0.0179     32     0.5713           2.33     7.90         431.9594          (1, 32, 1, 1002), (1, 32, 1002, 128), (1, 32, 1, 128)
rms_norm1                                         0.0059     65     0.3827           1.56     0.02         3.8878            (1, 1, 4096), (4096,), (1, 1, 4096)
fused_fused_decode9_fused_matmul6_cast2           0.3687     1      0.3687           1.50     56.71        150.2006          (824, 32000), (103, 32000), (1, 1, 4096), (1, 1, 32000)
fused_softmax_cast1                               0.0027     32     0.0859           0.35     0.18         66.7356           (1, 32, 1, 1002), (1, 32, 1, 1002)
full                                              0.0029     1      0.0029           0.01     0.00         0.6545            (1, 1, 1, 1002)
fused_fused_decode1_take                          0.0026     1      0.0026           0.01     56.59        20890.1001        (32000, 824), (32000, 103), (1,), (1, 4096)
Total time: 24.5482 ms
```
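For anyone who wants to compare the two runs numerically, here is a minimal sketch that divides the "with this PR" totals by the "before" totals, using the labels exactly as they appear in the logs above. The figures are copied from the encoding tables; the names `ENCODING_TOTAL_MS`, `MATMUL_TOTAL_MS`, and `ratio` are hypothetical helpers for illustration and are not part of TVM or of this PR.

```python
# Encoding "Total time" values (ms), copied from the two profiling logs above.
ENCODING_TOTAL_MS = {"before": 707.7763, "with_pr": 2184.8893}

# Per-kernel "Total time (ms)" for the NT_matmul kernels in the encoding tables,
# stored as (before, with_pr). Values are copied from the logs, not re-measured.
MATMUL_TOTAL_MS = {
    "NT_matmul1": (112.0528, 486.4369),
    "fused_NT_matmul3_silu1": (142.6060, 436.7056),
    "fused_NT_matmul3_multiply1": (106.0936, 443.3747),
    "fused_NT_matmul4_add1": (102.9257, 448.5585),
    "fused_NT_matmul1_add1": (43.7674, 169.1377),
}


def ratio(before_ms: float, with_pr_ms: float) -> float:
    """Return with_pr / before; values above 1 mean the second run took longer."""
    return with_pr_ms / before_ms


if __name__ == "__main__":
    total = ratio(ENCODING_TOTAL_MS["before"], ENCODING_TOTAL_MS["with_pr"])
    print(f"encoding total: {total:.3f}x")
    for name, (before_ms, with_pr_ms) in MATMUL_TOTAL_MS.items():
        print(f"{name}: {ratio(before_ms, with_pr_ms):.3f}x")
```

A ratio above 1 means the run labeled "with this PR" spent more time in that kernel than the run labeled "before"; a ratio below 1 means it spent less.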