This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7dd248b0be [Unity][Dlight] Choose perfect spatial factor in reduction 
rule (#16101)
7dd248b0be is described below

commit 7dd248b0be9ff57cd3d853d9a679833896020cb0
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Nov 10 05:05:47 2023 -0800

    [Unity][Dlight] Choose perfect spatial factor in reduction rule (#16101)
---
 python/tvm/dlight/gpu/reduction.py        | 10 ++++-
 tests/python/dlight/test_gpu_reduction.py | 68 ++++++++++++++++++++++++++++++-
 2 files changed, 76 insertions(+), 2 deletions(-)

diff --git a/python/tvm/dlight/gpu/reduction.py 
b/python/tvm/dlight/gpu/reduction.py
index ddc0be3861..3e2e5ee532 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -99,7 +99,7 @@ class Reduction(ScheduleRule):
         if is_inner_reduction:
             self._sch_inner_reduction(sch, target, block, c_factor, epilogue)
         else:
-            self._sch_inner_spatial(sch, target, block, c_factor, epilogue)
+            self._sch_inner_spatial(sch, target, block, block_info, c_factor, 
epilogue)
         return sch
 
     def _normalize(  # pylint: disable=too-many-branches
@@ -198,12 +198,20 @@ class Reduction(ScheduleRule):
         sch: tir.Schedule,
         _: Target,
         block: tir.schedule.BlockRV,
+        block_info: BlockInfo,
         unroll_spatial_factor: Optional[int],
         epilogue_info: Optional[BlockInfo],
     ):
         # pylint: disable=invalid-name
         s, r, _ = sch.get_loops(block)
         len_tx, len_ty = 16, 16
+        s_factor = [i.dom for i in block_info.iters if i.kind == "S"][-1]
+        # Get a perfect spatial factor: the spatial factor should divide the 
innermost spatial loop so
+        # that the block produced by rfactor can be reverse-computed-at the 
original scope
+        while len_tx > 1:
+            if s_factor % len_tx == 0:
+                break
+            len_tx -= 1
         _, _ = sch.split(s, factors=[None, len_tx])
         _, ty = sch.split(r, factors=[None, len_ty])
         # Schedule the RF block
diff --git a/tests/python/dlight/test_gpu_reduction.py 
b/tests/python/dlight/test_gpu_reduction.py
index 6f4c6df25e..6198a2eb72 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -713,7 +713,6 @@ def test_reduction_inner_no_broadcasting():
 
 
 def test_reduction_inner_no_broadcasting2():
-
     # fmt: off
     @I.ir_module
     class Module:
@@ -795,5 +794,72 @@ def test_reduction_inner_no_broadcasting2():
     assert_structural_equal(mod, Expected)
 
 
+def test_reduction_inner_spatial_choose_perfect_factor():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(var_A: T.handle, var_B: T.handle, matmul: 
T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n = T.int64()
+            A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), 
n), "float16")
+            B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, 
T.int64(100)), "float16")
+            # with T.block("root"):
+            for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), 
T.int64(1), T.int64(100), n):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, 
i1, i2, i3, k])
+                    T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
+                    T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
+                    with T.init():
+                        matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
+                    matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, 
v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(var_A: T.handle, var_B: T.handle, matmul: 
T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            n = T.int64()
+            A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), 
n), "float16")
+            B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, 
T.int64(100)), "float16")
+            # with T.block("root"):
+            matmul_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), 
T.int64(32), T.int64(1), T.int64(100)), "float16", scope="local")
+            for ax0_ax1_fused_0 in T.thread_binding(T.int64(320), 
thread="blockIdx.x"):
+                for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), 
thread="threadIdx.x"):
+                    for ax2_fused_1 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
+                        with T.block("matmul_rf_init"):
+                            vax2_fused_1 = T.axis.spatial(T.int64(16), 
ax2_fused_1)
+                            v0 = T.axis.spatial(T.int64(32), (ax0_ax1_fused_0 
* T.int64(10) + ax0_ax1_fused_1) // T.int64(100))
+                            v1 = T.axis.spatial(T.int64(100), (ax0_ax1_fused_0 
* T.int64(10) + ax0_ax1_fused_1) % T.int64(100))
+                            T.reads()
+                            T.writes(matmul_rf_local[vax2_fused_1, T.int64(0), 
v0, T.int64(0), v1])
+                            matmul_rf_local[vax2_fused_1, T.int64(0), v0, 
T.int64(0), v1] = T.float16(0)
+                        for ax2_fused_0, u in T.grid((n + T.int64(15)) // 
T.int64(16), 1):
+                            with T.block("matmul_rf_update"):
+                                vax2_fused_1 = T.axis.spatial(T.int64(16), 
ax2_fused_1)
+                                v0 = T.axis.spatial(T.int64(32), 
(ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) // T.int64(100))
+                                v1 = T.axis.spatial(T.int64(100), 
(ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) % T.int64(100))
+                                vax2_fused_0 = T.axis.reduce((n + T.int64(15)) 
// T.int64(16), ax2_fused_0)
+                                T.where(ax2_fused_0 * T.int64(16) + 
ax2_fused_1 < n)
+                                T.reads(matmul_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1], A[T.int64(0), v0, T.int64(0), vax2_fused_0 * 
T.int64(16) + vax2_fused_1], B[T.int64(0), v0, vax2_fused_0 * T.int64(16) + 
vax2_fused_1, v1])
+                                T.writes(matmul_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1])
+                                matmul_rf_local[vax2_fused_1, T.int64(0), v0, 
T.int64(0), v1] = matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] 
+ A[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1] * 
B[T.int64(0), v0, vax2_fused_0 * T.int64(16) + vax2_fused_1, v1]
+                for ax1_ax2_fused in T.thread_binding(T.int64(10), 
thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
+                        with T.block("matmul"):
+                            vax2_fused_1 = T.axis.reduce(T.int64(16), ax0)
+                            v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused_0 
// T.int64(10))
+                            v1 = T.axis.spatial(T.int64(100), ax0_ax1_fused_0 
% T.int64(10) * T.int64(10) + ax1_ax2_fused)
+                            T.reads(matmul_rf_local[vax2_fused_1, T.int64(0), 
v0, T.int64(0), v1])
+                            T.writes(matmul[T.int64(0), v0, T.int64(0), v1])
+                            with T.init():
+                                matmul[T.int64(0), v0, T.int64(0), v1] = 
T.float16(0)
+                            matmul[T.int64(0), v0, T.int64(0), v1] = 
matmul[T.int64(0), v0, T.int64(0), v1] + matmul_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1]
+    # fmt: on
+
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to