This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e5f85c0e32 [DLIGHT][ADRENO] Fix for opencl adreno matmul schedule 
(#17259)
e5f85c0e32 is described below

commit e5f85c0e32046b6b1bdc5bd1a2485c645df4e730
Author: krishnaraj36 <[email protected]>
AuthorDate: Sat Aug 10 21:55:51 2024 +0530

    [DLIGHT][ADRENO] Fix for opencl adreno matmul schedule (#17259)
    
    Fixed the matmul schedule for the case of epilogue blocks
---
 python/tvm/dlight/gpu/matmul.py        | 50 ++++++++++++++-----
 tests/python/dlight/test_gpu_matmul.py | 89 ++++++++++++++++++----------------
 2 files changed, 85 insertions(+), 54 deletions(-)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 25cc649b44..5fb8e2469d 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -941,7 +941,7 @@ class Matmul(GPUScheduleRule):
                 inner_x=False,
             )
         elif target.kind.name == "opencl" and (
-            ("android" in str(target.host)) or ("windows" in str(target.host))
+            ("android" in str(target.host)) or ("adreno" in str(target.attrs))
         ):
             return Matmul.Config(
                 block_size_x=32,
@@ -991,7 +991,10 @@ class Matmul(GPUScheduleRule):
             end_it = block_stmt.reads[-1].region[-1].min
             return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == 
"R"
 
-        if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, 
iter_infos):
+        if (
+            target.kind.name == "opencl"
+            and (("android" in str(target.host)) or ("adreno" in 
str(target.attrs)))
+        ) and not is_inner_reduction(block_stmt, iter_infos):
             ret = self.sch_outer_reduction(sch, config, main_block, blocks)
             if ret is not None:
                 return ret
@@ -1122,6 +1125,16 @@ class Matmul(GPUScheduleRule):
         reduction_block: tir.schedule.BlockRV,
         blocks: List[tir.schedule.BlockRV],
     ) -> Optional[tir.Schedule]:
+
+        """Get vectorization factor"""
+
+        def get_max_factor(n, factors):
+            factors = sorted(factors, reverse=True)
+            for factor in factors:
+                if n % factor == 0:
+                    return factor
+            return 1
+
         reduction_loops = sch.get_loops(reduction_block)
         if not len(reduction_loops) == 4:
             return None
@@ -1140,13 +1153,17 @@ class Matmul(GPUScheduleRule):
             config.vector_size,
             config.unroll,
         )
-
-        is_dequant_block = len(blocks) > 1
-        if is_dequant_block:
-            compute_block, dequant_block, matmul_block = blocks
-            sch.compute_inline(compute_block)
-        else:
-            (matmul_block,) = blocks
+        VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 
8]), VecSize)
+        dequant_block = None
+        matmul_block = reduction_block
+        epilogue_block = None
+        if blocks[-1] is not matmul_block:
+            epilogue_block = blocks[-1]
+        for blk in blocks[:-1]:
+            if "dequantize" in sch.get(blk).name_hint:
+                dequant_block = blk
+            elif blk is not matmul_block:
+                sch.compute_inline(blk)
 
         m = sch.fuse(mb, ms)
 
@@ -1162,12 +1179,13 @@ class Matmul(GPUScheduleRule):
         sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)
 
         sch.compute_at(rmat_block, k0)
-        if is_dequant_block:
+        if dequant_block is not None:
             sch.compute_at(dequant_block, k3)
         sch.reverse_compute_at(wmat_block, mi)
         sch.set_scope(rmat_block, 0, "shared")
         sch.set_scope(matmul_block, 0, "local")
-        if is_dequant_block:
+
+        if dequant_block is not None:
             sch.set_scope(dequant_block, 0, "local")
 
         sch.bind(mo, "blockIdx.y")
@@ -1175,7 +1193,7 @@ class Matmul(GPUScheduleRule):
         sch.bind(mi, "threadIdx.y")
         sch.bind(ni, "threadIdx.x")
         sch.vectorize(sch.get_loops(matmul_block)[-1])
-        if is_dequant_block:
+        if dequant_block is not None:
             sch.vectorize(sch.get_loops(dequant_block)[-1])
 
         # Co-operative Memory Fetch
@@ -1187,7 +1205,7 @@ class Matmul(GPUScheduleRule):
         sch.vectorize(wv)
 
         # Scale and Quant Cache
-        if is_dequant_block:
+        if dequant_block is not None:
             qb = sch.cache_read(dequant_block, 0, "local")
             sb = sch.cache_read(dequant_block, 1, "local")
             sch.compute_at(sb, k1)
@@ -1197,5 +1215,11 @@ class Matmul(GPUScheduleRule):
             sch.vectorize(sch.get_loops(qb)[-1])
             sch.vectorize(sch.get_loops(sb)[-1])
 
+        if epilogue_block is not None:
+            sch.reverse_compute_at(epilogue_block, mi, 
preserve_unit_loops=True)
+            sch.set_scope(wmat_block, 0, "local")
+            sch.compute_inline(wmat_block)
+            sch.vectorize(sch.get_loops(epilogue_block)[-1])
+
         sch.decompose_reduction(matmul_block, k0)
         return sch
diff --git a/tests/python/dlight/test_gpu_matmul.py 
b/tests/python/dlight/test_gpu_matmul.py
index 4cef7f1c27..dc5276e62a 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -685,47 +685,54 @@ class TestMatmulAndroid(AndroidBeforeAfter):
 class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
     # fmt: off
     @T.prim_func
-    def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), 
lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: 
T.handle, p_output0: T.handle):
+    def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), 
lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: 
T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), 
"float16"), p_output0: T.handle):
         T.func_attr({"tir.noalias": T.bool(True)})
         seq_len = T.int64()
-        rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, 
T.int64(4096)), "float16")
-        matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, 
T.int64(12288)), "float16")
+        rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, 
T.int64(4096)), "float16")
+        T_add_intermediate_intermediate = T.match_buffer(p_output0, 
(T.int64(1), seq_len, T.int64(12288)), "float16")
         # with T.block("root"):
         compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
         dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), 
T.int64(12288)), "float16")
+        matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, 
T.int64(12288)), "float16")
         for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
             with T.block("compute"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(lv840[v_i0 // T.int64(8), v_i1])
+                T.reads(lv452[v_i0 // T.int64(8), v_i1])
                 T.writes(compute[v_i0, v_i1])
-                compute[v_i0, v_i1] = T.Cast("float16", 
T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", 
v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
+                compute[v_i0, v_i1] = T.Cast("float16", 
T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", 
v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
         for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
             with T.block("dequantize"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1])
+                T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1])
                 T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
-                dequantize_intermediate_intermediate[v_i0, v_i1] = 
(compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1]
+                dequantize_intermediate_intermediate[v_i0, v_i1] = 
(compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1]
         for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), 
T.int64(4096)):
             with T.block("matmul"):
                 v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
-                T.reads(rms_norm260[v_i0, v_i1, v_k], 
dequantize_intermediate_intermediate[v_k, v_i2])
+                T.reads(rms_norm130[v_i0, v_i1, v_k], 
dequantize_intermediate_intermediate[v_k, v_i2])
                 T.writes(matmul_intermediate[v_i0, v_i1, v_i2])
                 with T.init():
                     matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
-                matmul_intermediate[v_i0, v_i1, v_i2] = 
matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * 
dequantize_intermediate_intermediate[v_k, v_i2]
+                matmul_intermediate[v_i0, v_i1, v_i2] = 
matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * 
dequantize_intermediate_intermediate[v_k, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], 
transformer_h_0_attn_c_attn_bias3[v_ax2])
+                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
+                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = 
matmul_intermediate[v_ax0, v_ax1, v_ax2] + 
transformer_h_0_attn_c_attn_bias3[v_ax2]
 
     @T.prim_func
-    def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), 
lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: 
T.handle, p_output0: T.handle):
+    def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), 
lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: 
T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), 
"float16"), p_output0: T.handle):
         T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
         seq_len = T.int64()
-        rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, 
T.int64(4096)), "float16")
-        matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, 
T.int64(12288)), "float16")
+        rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, 
T.int64(4096)), "float16")
+        T_add_intermediate_intermediate = T.match_buffer(p_output0, 
(T.int64(1), seq_len, T.int64(12288)), "float16")
         # with T.block("root"):
         dequantize_intermediate_intermediate_local = 
T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local")
-        rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", 
scope="shared")
+        rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", 
scope="shared")
         matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", 
scope="local")
-        lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", 
scope="local")
-        lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)), 
"float16", scope="local")
+        lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", 
scope="local")
+        lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), 
"float16", scope="local")
         for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
             for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // 
T.int64(32), thread="blockIdx.y"):
                 for i2_1 in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
@@ -743,37 +750,37 @@ class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
                             for ax0 in range(T.int64(4)):
                                 for ax1_0 in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
                                     for ax1_1 in T.vectorized(T.int64(8)):
-                                        with T.block("rms_norm260_pad"):
+                                        with T.block("rms_norm130_pad"):
                                             v0 = T.axis.spatial(T.int64(1), 
T.int64(0))
                                             v1 = T.axis.spatial((seq_len + 
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + 
i0_i1_fused_1 * T.int64(4) + ax0)
                                             v2 = T.axis.spatial(T.int64(4096), 
k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
-                                            T.reads(rms_norm260[v0, v1, v2])
-                                            
T.writes(rms_norm260_pad_shared[v0, v1, v2])
-                                            rms_norm260_pad_shared[v0, v1, v2] 
= T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0))
+                                            T.reads(rms_norm130[v0, v1, v2])
+                                            
T.writes(rms_norm130_pad_shared[v0, v1, v2])
+                                            rms_norm130_pad_shared[v0, v1, v2] 
= T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0))
                             for k_1 in range(T.int64(8)):
                                 for ax0 in T.vectorized(T.int64(8)):
-                                    with T.block("lv841_local"):
+                                    with T.block("lv453_local"):
                                         v0 = T.axis.spatial(T.int64(128), k_0 
* T.int64(8) + k_1)
                                         v1 = T.axis.spatial(T.int64(12288), 
i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                        T.reads(lv841[v0, v1])
-                                        T.writes(lv841_local[v0, v1])
-                                        lv841_local[v0, v1] = lv841[v0, v1]
+                                        T.reads(lv453[v0, v1])
+                                        T.writes(lv453_local[v0, v1])
+                                        lv453_local[v0, v1] = lv453[v0, v1]
                                 for k_2 in range(T.int64(4)):
                                     for ax0 in T.vectorized(T.int64(8)):
-                                        with T.block("lv840_local"):
+                                        with T.block("lv452_local"):
                                             v0 = T.axis.spatial(T.int64(512), 
k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)
                                             v1 = 
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                            T.reads(lv840[v0, v1])
-                                            T.writes(lv840_local[v0, v1])
-                                            lv840_local[v0, v1] = lv840[v0, v1]
+                                            T.reads(lv452[v0, v1])
+                                            T.writes(lv452_local[v0, v1])
+                                            lv452_local[v0, v1] = lv452[v0, v1]
                                     for k_3 in range(T.int64(8)):
                                         for ax0 in T.vectorized(T.int64(8)):
                                             with T.block("dequantize"):
                                                 v_i0 = 
T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * 
T.int64(8) + k_3)
                                                 v_i1 = 
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                                T.reads(lv840_local[v_i0 // 
T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1])
+                                                T.reads(lv452_local[v_i0 // 
T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1])
                                                 
T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1])
-                                                
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], 
T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - 
T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1]
+                                                
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], 
T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - 
T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1]
                                         for i0_i1_fused_2 in range(T.int64(4)):
                                             for i2_2 in 
T.vectorized(T.int64(8)):
                                                 with T.block("matmul_update"):
@@ -781,19 +788,19 @@ class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
                                                     v_i1 = 
T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), 
i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
                                                     v_i2 = 
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
                                                     v_k = 
T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * 
T.int64(8) + k_3)
-                                                    
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], 
rms_norm260_pad_shared[v_i0, v_i1, v_k], 
dequantize_intermediate_intermediate_local[v_k, v_i2])
+                                                    
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], 
rms_norm130_pad_shared[v_i0, v_i1, v_k], 
dequantize_intermediate_intermediate_local[v_k, v_i2])
                                                     
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
-                                                    
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = 
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, 
v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
-                        for ax0 in range(T.int64(4)):
-                            for ax1 in T.vectorized(T.int64(8)):
-                                with T.block("matmul_intermediate_pad"):
-                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                    v1 = T.axis.spatial(seq_len, i0_i1_fused_0 
* T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
-                                    v2 = T.axis.spatial(T.int64(12288), i2_0 * 
T.int64(256) + i2_1 * T.int64(8) + ax1)
-                                    T.where((i0_i1_fused_0 - (seq_len + 
T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and 
i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len)
-                                    T.reads(matmul_intermediate_pad_local[v0, 
v1, v2])
-                                    T.writes(matmul_intermediate[v0, v1, v2])
-                                    matmul_intermediate[v0, v1, v2] = 
matmul_intermediate_pad_local[v0, v1, v2]
+                                                    
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = 
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, 
v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
+                        for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
+                            for ax2 in T.vectorized(T.int64(8)):
+                                with T.block("T_add"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(seq_len, 
i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(12288), 
i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2)
+                                    T.where(i0_i1_fused_0 * T.int64(32) + 
i0_i1_fused_1 * T.int64(4) + ax1 < seq_len)
+                                    
T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], 
transformer_h_0_attn_c_attn_bias3[v_ax2])
+                                    
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
+                                    T_add_intermediate_intermediate[v_ax0, 
v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + 
transformer_h_0_attn_c_attn_bias3[v_ax2]
     # fmt: on
 
 

Reply via email to