This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b01de08715 [DLight] Fix a corner case for reduction rule (#16848)
b01de08715 is described below

commit b01de087157e448c3454766393a057d9565e7d73
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Apr 5 17:12:54 2024 +0800

    [DLight] Fix a corner case for reduction rule (#16848)
    
    The current rule will fail when the output shape is only one element,
    because of the missing `preserve_unit_loops`. This PR fixes it and adds
    a test case.
---
 python/tvm/dlight/gpu/reduction.py        |  2 +-
 tests/python/dlight/test_gpu_reduction.py | 93 ++++++++++++++++++++++++++-----
 2 files changed, 79 insertions(+), 16 deletions(-)

diff --git a/python/tvm/dlight/gpu/reduction.py 
b/python/tvm/dlight/gpu/reduction.py
index 651e09dc52..4cc142ab16 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -217,7 +217,7 @@ class Reduction(GPUScheduleRule):
         # Schedule epilogue
         if epilogue_info is not None:
             epilogue = epilogue_info.block_rv
-            sch.reverse_compute_at(epilogue, bx)
+            sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
             if is_broadcast_epilogue(sch, block, epilogue):
                 sch.set_scope(block, 0, "shared")
                 _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
diff --git a/tests/python/dlight/test_gpu_reduction.py 
b/tests/python/dlight/test_gpu_reduction.py
index def124a9b2..1ce57eb53d 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -377,11 +377,12 @@ def test_decode_gemv_sigmoid():
                                 with T.init():
                                     C_local[0, 0, v0] = T.float16(0)
                                 C_local[0, 0, v0] = C_local[0, 0, v0] + 
C_rf_local[vax1_0_fused_1, 0, 0, v0]
-                    with T.block("sigmoid"):
-                        v0 = T.axis.spatial(4096, ax0_fused)
-                        T.reads(C_local[0, 0, v0])
-                        T.writes(D[0, 0, v0])
-                        D[0, 0, v0] = T.sigmoid(C_local[0, 0, v0])
+                    for ax0 in range(1):
+                        with T.block("sigmoid"):
+                            v0 = T.axis.spatial(4096, ax0_fused + ax0)
+                            T.reads(C_local[0, 0, v0])
+                            T.writes(D[0, 0, v0])
+                            D[0, 0, v0] = T.sigmoid(C_local[0, 0, v0])
 
     # fmt: on
 
@@ -465,11 +466,12 @@ def test_decode_gemv_1_fp32():
                                 with T.init():
                                     C_fp32_local[0, 0, v0] = T.float32(0)
                                 C_fp32_local[0, 0, v0] = C_fp32_local[0, 0, 
v0] + C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]
-                    with T.block("cast"):
-                        v0 = T.axis.spatial(4096, ax0_fused)
-                        T.reads(C_fp32_local[0, 0, v0])
-                        T.writes(C[0, 0, v0])
-                        C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0, v0])
+                    for ax0 in range(1):
+                        with T.block("cast"):
+                            v0 = T.axis.spatial(4096, ax0_fused + ax0)
+                            T.reads(C_fp32_local[0, 0, v0])
+                            T.writes(C[0, 0, v0])
+                            C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0, 
v0])
 
     # fmt: on
 
@@ -760,11 +762,12 @@ def test_reduction_inner_no_broadcasting():
                             with T.init():
                                 temp_local_local[v0] = T.float32(0)
                             temp_local_local[v0] = temp_local_local[v0] + 
temp_local_rf_local[vax1_fused_1, v0]
-                with T.block("add"):
-                    v0 = T.axis.spatial(256, ax0_fused)
-                    T.reads(temp_local_local[v0])
-                    T.writes(B[v0])
-                    B[v0] = temp_local_local[v0] + T.float32(1)
+                for ax0 in range(1):
+                    with T.block("add"):
+                        v0 = T.axis.spatial(256, ax0_fused + ax0)
+                        T.reads(temp_local_local[v0])
+                        T.writes(B[v0])
+                        B[v0] = temp_local_local[v0] + T.float32(1)
     # fmt: on
 
     target = Target("nvidia/geforce-rtx-3090-ti")
@@ -1089,5 +1092,65 @@ def test_gemv_dyn_shape_epilogue():
     assert_structural_equal(mod, Expected)
 
 
+def test_gemv_output_one_element():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: 
T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1), 
T.int64(1)), "float16")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1)), 
"float16")
+            for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
+                with T.block("NT_matmul"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    with T.init():
+                        NT_matmul_intermediate[v_i0, v_i1] = T.float16(0)
+                    NT_matmul_intermediate[v_i0, v_i1] = 
NT_matmul_intermediate[v_i0, v_i1] + A[v_i0, v_k] * weight[v_i1, v_k]
+            for i0, i1 in T.grid(T.int64(1), T.int64(1)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    out[v_i0, v_i1] = T.sigmoid(NT_matmul_intermediate[v_i0, 
v_i1])
+
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: 
T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1), 
T.int64(1)), "float16")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            NT_matmul_intermediate_shared = T.alloc_buffer((T.int64(1), 
T.int64(1)), "float16", scope="shared")
+            NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(1024), 
T.int64(1), T.int64(1)), "float16", scope="local")
+            for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+                for ax1_fused_1 in T.thread_binding(T.int64(1024), 
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
+                    with T.block("NT_matmul_rf_init"):
+                        vax1_fused_1 = T.axis.spatial(T.int64(1024), 
ax1_fused_1)
+                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                        NT_matmul_intermediate_rf_local[vax1_fused_1, 
T.int64(0), T.int64(0)] = T.float16(0)
+                    for ax1_fused_0, u in T.grid(T.int64(2), 1):
+                        with T.block("NT_matmul_rf_update"):
+                            vax1_fused_1 = T.axis.spatial(T.int64(1024), 
ax1_fused_1)
+                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            vax1_fused_0 = T.axis.reduce(T.int64(2), 
ax1_fused_0)
+                            NT_matmul_intermediate_rf_local[vax1_fused_1, 
T.int64(0), T.int64(0)] = NT_matmul_intermediate_rf_local[vax1_fused_1, 
T.int64(0), T.int64(0)] + A[T.int64(0), vax1_fused_0 * T.int64(1024) + 
vax1_fused_1] * weight[T.int64(0), vax1_fused_0 * T.int64(1024) + vax1_fused_1]
+                for ax1_fused in range(T.int64(1)):
+                    for ax0 in T.thread_binding(T.int64(1024), 
thread="threadIdx.x"):
+                        with T.block("NT_matmul"):
+                            vax1_fused_1 = T.axis.reduce(T.int64(1024), ax0)
+                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            with T.init():
+                                NT_matmul_intermediate_shared[T.int64(0), 
T.int64(0)] = T.float16(0)
+                            NT_matmul_intermediate_shared[T.int64(0), 
T.int64(0)] = NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] + 
NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)]
+                for ax0_fused_0 in range(T.int64(1)):
+                    for ax0_fused_1 in T.thread_binding(T.int64(1024), 
thread="threadIdx.x"):
+                        with T.block("compute"):
+                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 
< T.int64(1))
+                            out[T.int64(0), T.int64(0)] = 
T.sigmoid(NT_matmul_intermediate_shared[T.int64(0), T.int64(0)])
+    # fmt: on
+
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to