This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 07ee13f663 [Unity][Dlight] GeMV rule max_num_threads awareness (#15647)
07ee13f663 is described below

commit 07ee13f663aefe533d93cc126d3cddcdcb9a52a6
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Aug 31 15:48:29 2023 -0400

    [Unity][Dlight] GeMV rule max_num_threads awareness (#15647)
    
    Prior to this PR, the GeMV dlight rule was not aware of the maximum
    number of threads supported by the backend target. This PR adds that
    awareness, mainly for the benefit of WebGPU, where each thread block
    can have at most 256 threads.
---
 python/tvm/dlight/gpu/gemv.py        |   7 +++
 tests/python/dlight/test_gpu_gemv.py | 106 +++++++++++++++++++++++++++++++++++
 2 files changed, 113 insertions(+)

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index a41b80756d..e06a0eb49b 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -463,6 +463,13 @@ class GEMV(ScheduleRule):
 
         if not isinstance(len_S, int):
             TS, TR = 1, 64
+
+        while TS * TR > target.max_num_threads:
+            if TS > 1:
+                TS //= 2
+            else:
+                TR //= 2
+
         TILE_S, TILE_R = (
             1,
             len_c
diff --git a/tests/python/dlight/test_gpu_gemv.py 
b/tests/python/dlight/test_gpu_gemv.py
index 2fe7c06e33..59e5fcc5fa 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -183,6 +183,112 @@ class TestGEMV(BaseBeforeAfter):
     # fmt: on
 
 
+def test_decode_gemv_256_threads():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: 
T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), 
var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16")
+        for i, j in T.grid(22016, 4096):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32])
+                T.writes(p_output0_intermediate[v_i, v_j])
+                p_output0_intermediate[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32]
+        for i0, i1, i2, k in T.grid(1, 1, 22016, 4096):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, 
v_k])
+                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * 
p_output0_intermediate[v_i2, v_k]
+
+    @T.prim_func(private=True)
+    def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: 
T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), 
var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
+        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 
22016), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 
22016), "float16", scope="local")
+        lv571_local = T.alloc_buffer((22016, 512), "uint32", scope="local")
+        lv1654_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
+        for u_fused_ax0_fused_fused_0 in T.thread_binding(688, 
thread="blockIdx.x"):
+            for u_fused_ax0_fused_fused_1 in T.thread_binding(32, 
thread="threadIdx.y"):
+                for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in 
T.thread_binding(8, thread="threadIdx.x"):
+                    for ax0, ax1 in T.grid(1, 1):
+                        for ax2_0 in T.serial(4, 
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+                            for ax2_1 in T.thread_binding(32, 
thread="threadIdx.y"):
+                                for ax2_2 in T.thread_binding(8, 
thread="threadIdx.x"):
+                                    for ax2_3 in T.vectorized(4):
+                                        with T.block("lv1654_shared"):
+                                            v0, v1 = T.axis.remap("SS", [ax0, 
ax1])
+                                            v2 = T.axis.spatial(4096, ax2_0 * 
1024 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
+                                            T.reads(lv1654[v0, v1, v2])
+                                            T.writes(lv1654_shared[v0, v1, v2])
+                                            lv1654_shared[v0, v1, v2] = 
lv1654[v0, v1, v2]
+                    for u_fused_ax0_fused_fused_2_init in range(1):
+                        for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in 
T.vectorized(2):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+                                v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + 
u_fused_ax0_fused_fused_2_init)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
+                                
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = T.float16(0)
+                    for ax1_0_fused_ax1_1_fused_0 in T.serial(64, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax0_0, ax1 in T.grid(1, 1):
+                            for ax0_1 in T.vectorized(1):
+                                with T.block("lv571_local"):
+                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
+                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 8 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    T.reads(lv571[v0, v1])
+                                    T.writes(lv571_local[v0, v1])
+                                    lv571_local[v0, v1] = lv571[v0, v1]
+                        for u_fused_ax0_fused_fused_2, 
ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
+                            for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + 
u_fused_ax0_fused_fused_2)
+                                    vax1_0_fused_ax1_1_fused_0, 
vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, 
ax1_0_fused_ax1_1_fused_2])
+                                    
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0], lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + 
vax1_0_fused_ax1_1_fused_2 * 2 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], 
lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 
4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fus [...]
+                                    
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
+                                    
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] + lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + 
vax1_0_fused_ax1_1_fused_2 * 2 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * 
((T.Cast("float16 [...]
+            for ax2_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
+                for ax0 in T.thread_binding(8, thread="threadIdx.x"):
+                    for ax2_fused_1_0 in T.serial(1, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax2_fused_1_1 in T.vectorized(1):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = 
T.axis.spatial(8, ax0)
+                                v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                                
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] = T.float16(0)
+                            for ax1 in range(2):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = 
T.axis.remap("SR", [ax0, ax1])
+                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                    
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0], 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
+                                    
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                                    
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] = 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] + 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
+            for ax1_fused_1 in range(1):
+                for ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
+                    for ax0 in T.thread_binding(8, thread="threadIdx.x"):
+                        with T.block("NT_matmul"):
+                            
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, 
ax0)
+                            v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 32 + ax1_fused_0 + ax1_fused_1)
+                            
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                            T.writes(var_NT_matmul_intermediate[0, 0, v0])
+                            with T.init():
+                                var_NT_matmul_intermediate[0, 0, v0] = 
T.float16(0)
+                            var_NT_matmul_intermediate[0, 0, v0] = 
var_NT_matmul_intermediate[0, 0, v0] + 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0]
+    # fmt: on
+
+    mod = tvm.IRModule({"main": before})
+    with Target("apple/m1-gpu-restricted"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 def test_decode_gemv1():
     # fmt: off
 

Reply via email to