This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5c80691c81 [Dlight] Enhance vectorization loading weight for gemv (#16878)
5c80691c81 is described below
commit 5c80691c81070df0d79fa22f64579945f4807c5e
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Apr 13 11:48:00 2024 -0700
[Dlight] Enhance vectorization loading weight for gemv (#16878)
* [Dlight] Enhance vectorization loading weight for gemv
* Update gemv.py
---
python/tvm/dlight/gpu/gemv.py | 18 ++++++------
tests/python/dlight/test_gpu_gemv.py | 57 ++++++++++++++++++------------------
2 files changed, 38 insertions(+), 37 deletions(-)
diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index 55b38fc66b..c1ce876620 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""A rule for GEMV and DecodeGEMV."""
-import re
from functools import reduce
from typing import List, Optional, Union
@@ -56,10 +55,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
def get_bytes(dtype: Union[DataType, str]) -> int:
- num = re.findall(r"\d+", dtype)
- if len(num) != 1:
- raise ValueError(f"Cannot get bytes from {dtype}")
- return int(num[0]) // 8
+ if isinstance(dtype, str):
+ dtype = DataType(dtype)
+ return dtype.bits * dtype.lanes // 8
def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
@@ -297,10 +295,11 @@ class GEMV(GPUScheduleRule):
            Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local")
sch.compute_at(Aq_local, r, preserve_unit_loops=True)
s_local, r_local = sch.get_loops(block=Aq_local)[-2:]
-            s_local, vec_load = sch.split(
-                s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
+ fused_load = sch.fuse(s_local, r_local)
+            aq_vec_len = max(1, VEC_LOAD // get_bytes(sch.get(Aq_local).reads[0].buffer.dtype))
+            fused_load, vec_load = sch.split(
+                fused_load, factors=[None, aq_vec_len], preserve_unit_iters=True
)
-            sch.reorder(s_local, r_local, vec_load)  # either s_local or r_local should be 1
sch.vectorize(vec_load)
# load vector into shared memory, shape should be the whole vector
@@ -442,10 +441,12 @@ class GEMV(GPUScheduleRule):
TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
SUPPORT_WARP_SHUFFLE = False
+ VEC_LOAD = 1
if target.kind.name == "cuda":
VEC_C = 4
LOAD_V_SHARED = True
LOAD_V_VEC = 8
+ VEC_LOAD = 4
UNROLL = 256
SUPPORT_WARP_SHUFFLE = True
if isinstance(len_S, int):
@@ -522,7 +523,6 @@ class GEMV(GPUScheduleRule):
            else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1),
)
VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)
- VEC_LOAD = 1
return apply(
sch,
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index 8903babbc0..0fd7f79159 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -120,13 +120,13 @@ class TestGEMV(BaseBeforeAfter):
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
0, v0, 0, v1] = T.float16(0)
for ax2_fused_u_fused_0 in T.serial(1,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2):
- for ax2_1 in T.vectorized(1):
+ for ax0, ax1, ax2_ax3_fused_0 in T.grid(1, 1, 1):
+ for ax2_ax3_fused_1 in T.vectorized(2):
with T.block("lv1638_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(32,
ax0_fused_ax1_fused_fused_0 // n + ax1)
- v2 = T.axis.spatial(n,
ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
- v3 = T.axis.spatial(128,
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3)
+ v2 = T.axis.spatial(n,
ax0_fused_ax1_fused_fused_0 % n)
+ v3 = T.axis.spatial(128,
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_ax3_fused_0 * 2 +
ax2_ax3_fused_1)
T.reads(lv1638[v0, v1, v2, v3])
T.writes(lv1638_local[v0, v1, v2, v3])
lv1638_local[v0, v1, v2, v3] = lv1638[v0,
v1, v2, v3]
@@ -224,11 +224,11 @@ def test_decode_gemv_256_threads():
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(32,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0_0, ax1 in T.grid(1, 1):
+ for ax0_ax1_fused in T.serial(1):
for ax0_1 in T.vectorized(1):
with T.block("lv571_local"):
- v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
- v1 = T.axis.spatial(512,
ax1_0_fused_ax1_1_fused_0 * 16 +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+ v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+ v1 = T.axis.spatial(512,
ax1_0_fused_ax1_1_fused_0 * 16 +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
T.reads(lv571[v0, v1])
T.writes(lv571_local[v0, v1])
lv571_local[v0, v1] = lv571[v0, v1]
@@ -332,11 +332,11 @@ def test_decode_gemv1():
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(8,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0_0, ax1 in T.grid(1, 1):
- for ax0_1 in T.vectorized(1):
+ for ax0_ax1_fused_0 in range(1):
+ for ax0_ax1_fused_1 in T.vectorized(1):
with T.block("lv571_local"):
- v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
- v1 = T.axis.spatial(512,
ax1_0_fused_ax1_1_fused_0 * 64 +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+ v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+ v1 = T.axis.spatial(512,
ax1_0_fused_ax1_1_fused_0 * 64 +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
T.reads(lv571[v0, v1])
T.writes(lv571_local[v0, v1])
lv571_local[v0, v1] = lv571[v0, v1]
@@ -448,11 +448,11 @@ def test_decode_gemv2():
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(8,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0_0, ax1 in T.grid(1, 1):
- for ax0_1 in T.vectorized(1):
+ for ax0_ax1_fused_0 in range(1):
+ for ax0_ax1_fused_1 in T.vectorized(1):
with T.block("lv771_local"):
- v0 = T.axis.spatial(32000,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
- v1 = T.axis.spatial(512,
ax1_0_fused_ax1_1_fused_0 * 64 +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+ v0 = T.axis.spatial(32000,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+ v1 = T.axis.spatial(512,
ax1_0_fused_ax1_1_fused_0 * 64 +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
T.reads(lv771[v0, v1])
T.writes(lv771_local[v0, v1])
lv771_local[v0, v1] = lv771[v0, v1]
@@ -572,11 +572,11 @@ def test_decode_gemv3():
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
T.int64(0), T.int64(0), v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
T.int64(0), T.int64(0), v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43),
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0_0, ax1 in T.grid(T.int64(1), T.int64(1)):
- for ax0_1 in T.vectorized(T.int64(1)):
+ for ax0_ax1_fused_0 in range(T.int64(1)):
+ for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
with T.block("lv575_local"):
- v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + ax0_0 +
ax0_1)
- v1 = T.axis.spatial(T.int64(1376),
ax1_0_fused_ax1_1_fused_0 * T.int64(32) +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+ v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1)
+ v1 = T.axis.spatial(T.int64(1376),
ax1_0_fused_ax1_1_fused_0 * T.int64(32) +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
T.reads(lv575[v0, v1])
T.writes(lv575_local[v0, v1])
lv575_local[v0, v1] = lv575[v0, v1]
@@ -942,15 +942,16 @@ def test_blockized_gemv():
T.writes(o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused,
v_expert_id_o, v0])
o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0] =
T.float16(0)
for ax1_fused_u_fused_0 in T.serial(32,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0, ax1_0, ax2 in T.grid(1, 1, 8):
- for ax1_1 in T.vectorized(1):
- with T.block("w_local"):
- v0 = T.axis.spatial(1, ax0)
- v1 = T.axis.spatial(16384,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax1_0 + ax1_1)
- v2 = T.axis.spatial(4096,
ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8
+ ax2)
- T.reads(w[indptr[v_expert_id_o] +
v0, v1, v2])
- T.writes(w_local[v0, v1, v2])
- w_local[v0, v1, v2] =
w[indptr[v_expert_id_o] + v0, v1, v2]
+ for ax0 in range(1):
+ for ax1_ax2_fused_0 in range(8):
+ for ax1_ax2_fused_1 in T.vectorized(1):
+ with T.block("w_local"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(16384,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+ v2 = T.axis.spatial(4096,
ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8
+ ax1_ax2_fused_0 + ax1_ax2_fused_1)
+
T.reads(w[indptr[v_expert_id_o] + v0, v1, v2])
+ T.writes(w_local[v0, v1, v2])
+ w_local[v0, v1, v2] =
w[indptr[v_expert_id_o] + v0, v1, v2]
for u_fused_ax0_fused_fused_2,
ax1_fused_u_fused_2 in T.grid(1, 8):
for
ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(1):
with T.block("gemv_rf_update"):