This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c80691c81 [Dlight] Enhance vectorization loading weight for gemv (#16878)
5c80691c81 is described below

commit 5c80691c81070df0d79fa22f64579945f4807c5e
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Apr 13 11:48:00 2024 -0700

    [Dlight] Enhance vectorization loading weight for gemv (#16878)
    
    * [Dlight] Enhance vectorization loading weight for gemv
    
    
    * Update gemv.py
---
 python/tvm/dlight/gpu/gemv.py        | 18 ++++++------
 tests/python/dlight/test_gpu_gemv.py | 57 ++++++++++++++++++------------------
 2 files changed, 38 insertions(+), 37 deletions(-)

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index 55b38fc66b..c1ce876620 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """A rule for GEMV and DecodeGEMV."""
-import re
 from functools import reduce
 from typing import List, Optional, Union
 
@@ -56,10 +55,9 @@ def get_extent(sch: tir.Schedule, loop_rv: 
tir.schedule.LoopRV):
 
 
 def get_bytes(dtype: Union[DataType, str]) -> int:
-    num = re.findall(r"\d+", dtype)
-    if len(num) != 1:
-        raise ValueError(f"Cannot get bytes from {dtype}")
-    return int(num[0]) // 8
+    if isinstance(dtype, str):
+        dtype = DataType(dtype)
+    return dtype.bits * dtype.lanes // 8
 
 
 def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> 
Optional[List[tir.Buffer]]:
@@ -297,10 +295,11 @@ class GEMV(GPUScheduleRule):
             Aq_local = sch.cache_read(rf, read_buffer_index=1, 
storage_scope="local")
             sch.compute_at(Aq_local, r, preserve_unit_loops=True)
             s_local, r_local = sch.get_loops(block=Aq_local)[-2:]
-            s_local, vec_load = sch.split(
-                s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
+            fused_load = sch.fuse(s_local, r_local)
+            aq_vec_len = max(1, VEC_LOAD // 
get_bytes(sch.get(Aq_local).reads[0].buffer.dtype))
+            fused_load, vec_load = sch.split(
+                fused_load, factors=[None, aq_vec_len], 
preserve_unit_iters=True
             )
-            sch.reorder(s_local, r_local, vec_load)  # either s_local or 
r_local should be 1
             sch.vectorize(vec_load)
 
             # load vector into shared memory, shape should be the whole vector
@@ -442,10 +441,12 @@ class GEMV(GPUScheduleRule):
 
         TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
         SUPPORT_WARP_SHUFFLE = False
+        VEC_LOAD = 1
         if target.kind.name == "cuda":
             VEC_C = 4
             LOAD_V_SHARED = True
             LOAD_V_VEC = 8
+            VEC_LOAD = 4
             UNROLL = 256
             SUPPORT_WARP_SHUFFLE = True
             if isinstance(len_S, int):
@@ -522,7 +523,6 @@ class GEMV(GPUScheduleRule):
             else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) 
// TR, 1),
         )
         VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)
-        VEC_LOAD = 1
 
         return apply(
             sch,
diff --git a/tests/python/dlight/test_gpu_gemv.py 
b/tests/python/dlight/test_gpu_gemv.py
index 8903babbc0..0fd7f79159 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -120,13 +120,13 @@ class TestGEMV(BaseBeforeAfter):
                                 
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 0, v0, 0, v1])
                                 
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 0, v0, 0, v1] = T.float16(0)
                     for ax2_fused_u_fused_0 in T.serial(1, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                        for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2):
-                            for ax2_1 in T.vectorized(1):
+                        for ax0, ax1, ax2_ax3_fused_0 in T.grid(1, 1, 1):
+                            for ax2_ax3_fused_1 in T.vectorized(2):
                                 with T.block("lv1638_local"):
                                     v0 = T.axis.spatial(1, ax0)
                                     v1 = T.axis.spatial(32, 
ax0_fused_ax1_fused_fused_0 // n + ax1)
-                                    v2 = T.axis.spatial(n, 
ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
-                                    v3 = T.axis.spatial(128, 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3)
+                                    v2 = T.axis.spatial(n, 
ax0_fused_ax1_fused_fused_0 % n)
+                                    v3 = T.axis.spatial(128, 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_ax3_fused_0 * 2 + 
ax2_ax3_fused_1)
                                     T.reads(lv1638[v0, v1, v2, v3])
                                     T.writes(lv1638_local[v0, v1, v2, v3])
                                     lv1638_local[v0, v1, v2, v3] = lv1638[v0, 
v1, v2, v3]
@@ -224,11 +224,11 @@ def test_decode_gemv_256_threads():
                                 
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
                                 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = T.float16(0)
                     for ax1_0_fused_ax1_1_fused_0 in T.serial(32, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                        for ax0_0, ax1 in T.grid(1, 1):
+                        for ax0_ax1_fused in T.serial(1):
                             for ax0_1 in T.vectorized(1):
                                 with T.block("lv571_local"):
-                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
-                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 16 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 16 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                     T.reads(lv571[v0, v1])
                                     T.writes(lv571_local[v0, v1])
                                     lv571_local[v0, v1] = lv571[v0, v1]
@@ -332,11 +332,11 @@ def test_decode_gemv1():
                                 
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
                                 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = T.float16(0)
                     for ax1_0_fused_ax1_1_fused_0 in T.serial(8, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                        for ax0_0, ax1 in T.grid(1, 1):
-                            for ax0_1 in T.vectorized(1):
+                        for ax0_ax1_fused_0 in range(1):
+                            for ax0_ax1_fused_1 in T.vectorized(1):
                                 with T.block("lv571_local"):
-                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
-                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 64 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 64 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                     T.reads(lv571[v0, v1])
                                     T.writes(lv571_local[v0, v1])
                                     lv571_local[v0, v1] = lv571[v0, v1]
@@ -448,11 +448,11 @@ def test_decode_gemv2():
                                 
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
                                 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = T.float16(0)
                     for ax1_0_fused_ax1_1_fused_0 in T.serial(8, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                        for ax0_0, ax1 in T.grid(1, 1):
-                            for ax0_1 in T.vectorized(1):
+                        for ax0_ax1_fused_0 in range(1):
+                            for ax0_ax1_fused_1 in T.vectorized(1):
                                 with T.block("lv771_local"):
-                                    v0 = T.axis.spatial(32000, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
-                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 64 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    v0 = T.axis.spatial(32000, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 64 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                     T.reads(lv771[v0, v1])
                                     T.writes(lv771_local[v0, v1])
                                     lv771_local[v0, v1] = lv771[v0, v1]
@@ -572,11 +572,11 @@ def test_decode_gemv3():
                                 
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 T.int64(0), T.int64(0), v0])
                                 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 T.int64(0), T.int64(0), v0] = T.float16(0)
                     for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                        for ax0_0, ax1 in T.grid(T.int64(1), T.int64(1)):
-                            for ax0_1 in T.vectorized(T.int64(1)):
+                        for ax0_ax1_fused_0 in range(T.int64(1)):
+                            for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
                                 with T.block("lv575_local"):
-                                    v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + ax0_0 + 
ax0_1)
-                                    v1 = T.axis.spatial(T.int64(1376), 
ax1_0_fused_ax1_1_fused_0 * T.int64(32) + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1)
+                                    v1 = T.axis.spatial(T.int64(1376), 
ax1_0_fused_ax1_1_fused_0 * T.int64(32) + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
                                     T.reads(lv575[v0, v1])
                                     T.writes(lv575_local[v0, v1])
                                     lv575_local[v0, v1] = lv575[v0, v1]
@@ -942,15 +942,16 @@ def test_blockized_gemv():
                                         
T.writes(o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, 
v_expert_id_o, v0])
                                         
o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0] = 
T.float16(0)
                             for ax1_fused_u_fused_0 in T.serial(32, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                                for ax0, ax1_0, ax2 in T.grid(1, 1, 8):
-                                    for ax1_1 in T.vectorized(1):
-                                        with T.block("w_local"):
-                                            v0 = T.axis.spatial(1, ax0)
-                                            v1 = T.axis.spatial(16384, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax1_0 + ax1_1)
-                                            v2 = T.axis.spatial(4096, 
ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 
+ ax2)
-                                            T.reads(w[indptr[v_expert_id_o] + 
v0, v1, v2])
-                                            T.writes(w_local[v0, v1, v2])
-                                            w_local[v0, v1, v2] = 
w[indptr[v_expert_id_o] + v0, v1, v2]
+                                for ax0 in range(1):
+                                    for ax1_ax2_fused_0 in range(8):
+                                        for ax1_ax2_fused_1 in T.vectorized(1):
+                                            with T.block("w_local"):
+                                                v0 = T.axis.spatial(1, ax0)
+                                                v1 = T.axis.spatial(16384, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+                                                v2 = T.axis.spatial(4096, 
ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 
+ ax1_ax2_fused_0 + ax1_ax2_fused_1)
+                                                
T.reads(w[indptr[v_expert_id_o] + v0, v1, v2])
+                                                T.writes(w_local[v0, v1, v2])
+                                                w_local[v0, v1, v2] = 
w[indptr[v_expert_id_o] + v0, v1, v2]
                                 for u_fused_ax0_fused_fused_2, 
ax1_fused_u_fused_2 in T.grid(1, 8):
                                     for 
ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(1):
                                         with T.block("gemv_rf_update"):

Reply via email to