This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cf401bc6b4 [Unity][Dlight] Fix decode-GeMV rule when spatial-inner 
without broadcasting (#15330)
cf401bc6b4 is described below

commit cf401bc6b4776033b09548187d88aa2297026cac
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Jul 16 07:09:07 2023 -0700

    [Unity][Dlight] Fix decode-GeMV rule when spatial-inner without 
broadcasting (#15330)
    
    This PR fixes a bug in the existing decode-GeMV dlight schedule rule.
    
    Previously, when the inner dimension of the largest tensor was spatial,
    the fused epilogue block ended up not bound to any thread axis. This is
    incorrect and produces invalid GPU code with wrong numerical results:
    after reverse-compute-at of the epilogue block, at least one spatial
    axis remains, and that axis must be bound to threadIdx.
    
    This PR fixes this issue and adds three test cases: spatial-inner
    without broadcasting, spatial-inner with broadcasting, and
    reduction-inner without broadcasting.
---
 python/tvm/dlight/gpu/decode_gemv.py        |   6 +-
 tests/python/dlight/test_gpu_decode_gemv.py | 230 +++++++++++++++++++++++++++-
 2 files changed, 228 insertions(+), 8 deletions(-)

diff --git a/python/tvm/dlight/gpu/decode_gemv.py 
b/python/tvm/dlight/gpu/decode_gemv.py
index afcfdb3020..1aa5d68fc5 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -233,7 +233,11 @@ class DecodeGEMV(ScheduleRule):
                 _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
                 _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx, 
len_ty])
                 sch.bind(tx, "threadIdx.x")
-                sch.bind(ty, "threadIdx.x")
+                sch.bind(ty, "threadIdx.y")
             else:
+                # The epilogue is element-wise without broadcasting.
+                # Thus the remaining spatial part should be bind to tx.
                 sch.set_scope(block, 0, "local")
+                _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
+                sch.bind(sch.fuse(*s), "threadIdx.x")
         # pylint: enable=invalid-name
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py 
b/tests/python/dlight/test_gpu_decode_gemv.py
index 7b19e6b7f8..971f5f4d09 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: 
disable=missing-docstring,line-too-long,invalid-name,too-few-public-methods,too-many-locals
+
+import tvm.testing
 from tvm import dlight as dl
 from tvm.ir import assert_structural_equal
 from tvm.script import ir as I
@@ -489,11 +491,225 @@ def test_reduction_no_spatial():
     assert_structural_equal(mod, After)
 
 
+def test_spatial_inner_no_broadcasting():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: 
T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), 
lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 
1, 4096), "float16")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16")
+            var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16")
+            for i, j in T.grid(11008, 4096):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j])
+                    T.writes(p_output0_intermediate_1[v_i, v_j])
+                    p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i // 32, v_j]
+            for i0, i1, i2, k in T.grid(1, 1, 4096, 11008):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(lv574[v_i0, v_i1, v_k], 
p_output0_intermediate_1[v_k, v_i2])
+                    T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                    with T.init():
+                        var_matmul_intermediate[v_i0, v_i1, v_i2] = 
T.float16(0)
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * 
p_output0_intermediate_1[v_k, v_i2]
+            for ax0, ax1, ax2 in T.grid(1, 1, 4096):
+                with T.block("T_add"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(lv570[v_ax0, v_ax1, v_ax2], 
var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+                    T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                    p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, 
v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: 
T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), 
lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 
1, 4096), "float16")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), 
"float16", scope="local")
+            var_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 
4096), "float16", scope="local")
+            for ax0_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
+                    for ax1_0_fused_1 in T.thread_binding(16, 
thread="threadIdx.y"):
+                        with T.block("matmul_rf_init"):
+                            vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1)
+                            v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + 
ax0_fused_1)
+                            T.reads()
+                            
T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
+                            var_matmul_intermediate_rf_local[vax1_0_fused_1, 
0, 0, v0] = T.float16(0)
+                        for ax1_0_fused_0, ax1_1 in T.grid(86, 8):
+                            with T.block("matmul_rf_update"):
+                                vax1_0_fused_1 = T.axis.spatial(16, 
ax1_0_fused_1)
+                                v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + 
ax0_fused_1)
+                                vax1_0_fused_0, vax1_1 = T.axis.remap("RR", 
[ax1_0_fused_0, ax1_1])
+                                
T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv574[0, 0, 
vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1], lv575[(vax1_0_fused_0 * 
128 + vax1_0_fused_1 * 8 + vax1_1) // 8, v0], lv576[(vax1_0_fused_0 * 128 + 
vax1_0_fused_1 * 8 + vax1_1) // 32, v0])
+                                
T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
+                                
var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = 
var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv574[0, 0, 
vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(lv575[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + 
vax1_1) // 8, v0], T.Cast("uint32", (vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 
+ vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * 
lv576[(vax1_0_fused_0 * 128 +  [...]
+                for ax1_fused in T.thread_binding(16, thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(16, thread="threadIdx.y"):
+                        with T.block("matmul"):
+                            vax1_0_fused_1 = T.axis.reduce(16, ax0)
+                            v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + 
ax1_fused)
+                            
T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
+                            T.writes(var_matmul_intermediate_local[0, 0, v0])
+                            with T.init():
+                                var_matmul_intermediate_local[0, 0, v0] = 
T.float16(0)
+                            var_matmul_intermediate_local[0, 0, v0] = 
var_matmul_intermediate_local[0, 0, v0] + 
var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]
+                for ax0_fused in T.thread_binding(16, thread="threadIdx.x"):
+                    with T.block("T_add"):
+                        v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused)
+                        T.reads(lv570[0, 0, v0], 
var_matmul_intermediate_local[0, 0, v0])
+                        T.writes(p_output0_intermediate[0, 0, v0])
+                        p_output0_intermediate[0, 0, v0] = lv570[0, 0, v0] + 
var_matmul_intermediate_local[0, 0, v0]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
+def test_spatial_inner_broadcasting():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            temp_local = T.alloc_buffer((256,))
+            for j in T.serial(256):
+                for k in T.serial(256):
+                    with T.block("sum"):
+                        vj, vk = T.axis.remap("SR", [j, k])
+                        T.reads(A[vk, vj])
+                        T.writes(temp_local[vj])
+                        with T.init():
+                            temp_local[vj] = T.float32(0)
+                        temp_local[vj] = temp_local[vj] + A[vk, vj]
+            for i, j in T.grid(256, 256):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.reads(temp_local[vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] + temp_local[vj]
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), 
"float32")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            temp_local_shared = T.alloc_buffer((256,), scope="shared")
+            temp_local_rf_local = T.alloc_buffer((16, 256), scope="local")
+            for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
+                    for ax1_fused_1 in T.thread_binding(16, 
thread="threadIdx.y"):
+                        with T.block("sum_rf_init"):
+                            vax1_fused_1 = T.axis.spatial(16, ax1_fused_1)
+                            v0 = T.axis.spatial(256, ax0_fused_0 * 16 + 
ax0_fused_1)
+                            T.reads()
+                            T.writes(temp_local_rf_local[vax1_fused_1, v0])
+                            temp_local_rf_local[vax1_fused_1, v0] = 
T.float32(0)
+                        for ax1_fused_0, u in T.grid(16, 1):
+                            with T.block("sum_rf_update"):
+                                vax1_fused_1 = T.axis.spatial(16, ax1_fused_1)
+                                v0 = T.axis.spatial(256, ax0_fused_0 * 16 + 
ax0_fused_1)
+                                vax1_fused_0 = T.axis.reduce(16, ax1_fused_0)
+                                T.reads(temp_local_rf_local[vax1_fused_1, v0], 
A[vax1_fused_0 * 16 + vax1_fused_1, v0])
+                                T.writes(temp_local_rf_local[vax1_fused_1, v0])
+                                temp_local_rf_local[vax1_fused_1, v0] = 
temp_local_rf_local[vax1_fused_1, v0] + A[vax1_fused_0 * 16 + vax1_fused_1, v0]
+                for ax1_fused in T.thread_binding(16, thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(16, thread="threadIdx.y"):
+                        with T.block("sum"):
+                            vax1_fused_1 = T.axis.reduce(16, ax0)
+                            v0 = T.axis.spatial(256, ax0_fused_0 * 16 + 
ax1_fused)
+                            T.reads(temp_local_rf_local[vax1_fused_1, v0])
+                            T.writes(temp_local_shared[v0])
+                            with T.init():
+                                temp_local_shared[v0] = T.float32(0)
+                            temp_local_shared[v0] = temp_local_shared[v0] + 
temp_local_rf_local[vax1_fused_1, v0]
+                for ax0_ax1_fused_0 in range(16):
+                    for ax0_ax1_fused_1 in T.thread_binding(16, 
thread="threadIdx.x"):
+                        for ax0_ax1_fused_2 in T.thread_binding(16, 
thread="threadIdx.y"):
+                            with T.block("add"):
+                                v0 = T.axis.spatial(256, (ax0_ax1_fused_0 * 
256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 16)
+                                v1 = T.axis.spatial(256, ax0_fused_0 * 16 + 
(ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 16)
+                                T.reads(temp_local_shared[v1])
+                                T.writes(B[v0, v1])
+                                B[v0, v1] = A[v0, v1] + temp_local_shared[v1]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
+def test_reduction_inner_no_broadcasting():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), 
"float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            temp_local = T.alloc_buffer((256,))
+            for i in T.serial(256):
+                for k in T.serial(256):
+                    with T.block("sum"):
+                        vi, vk = T.axis.remap("SR", [i, k])
+                        T.reads(A[vi, vk])
+                        T.writes(temp_local[vi])
+                        with T.init():
+                            temp_local[vi] = T.float32(0)
+                        temp_local[vi] = temp_local[vi] + A[vi, vk]
+            for i in T.grid(256):
+                with T.block("add"):
+                    vi = T.axis.remap("S", [i])
+                    T.reads(temp_local[vi])
+                    T.writes(B[vi,])
+                    B[vi] = temp_local[vi] + T.float32(1)
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), 
"float32")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            temp_local_local = T.alloc_buffer((256,), scope="local")
+            temp_local_rf_local = T.alloc_buffer((256, 256), scope="local")
+            for ax0_fused in T.thread_binding(256, thread="blockIdx.x"):
+                for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                    with T.block("sum_rf_init"):
+                        vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, 
ax0_fused])
+                        T.reads()
+                        T.writes(temp_local_rf_local[vax1_fused_1, v0])
+                        temp_local_rf_local[vax1_fused_1, v0] = T.float32(0)
+                    for ax1_fused_0, u in T.grid(1, 1):
+                        with T.block("sum_rf_update"):
+                            vax1_fused_1, v0, vax1_fused_0 = 
T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0])
+                            T.reads(temp_local_rf_local[vax1_fused_1, v0], 
A[v0, vax1_fused_0 * 256 + vax1_fused_1])
+                            T.writes(temp_local_rf_local[vax1_fused_1, v0])
+                            temp_local_rf_local[vax1_fused_1, v0] = 
temp_local_rf_local[vax1_fused_1, v0] + A[v0, vax1_fused_0 * 256 + vax1_fused_1]
+                for ax1_fused in range(1):
+                    for ax0 in T.thread_binding(256, thread="threadIdx.x"):
+                        with T.block("sum"):
+                            vax1_fused_1, v0 = T.axis.remap("RS", [ax0, 
ax0_fused])
+                            T.reads(temp_local_rf_local[vax1_fused_1, v0])
+                            T.writes(temp_local_local[v0])
+                            with T.init():
+                                temp_local_local[v0] = T.float32(0)
+                            temp_local_local[v0] = temp_local_local[v0] + 
temp_local_rf_local[vax1_fused_1, v0]
+                with T.block("add"):
+                    v0 = T.axis.spatial(256, ax0_fused)
+                    T.reads(temp_local_local[v0])
+                    T.writes(B[v0])
+                    B[v0] = temp_local_local[v0] + T.float32(1)
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
-    test_decode_gemv_1()
-    test_decode_gemv_2()
-    test_decode_gemv_3()
-    test_decode_gemv_4()
-    test_decode_gemv_sigmoid()
-    test_decode_gemv_1_fp32()
-    test_reduction_no_spatial()
+    tvm.testing.main()

Reply via email to