This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7dd248b0be [Unity][Dlight] Choose perfect spatial factor in reduction
rule (#16101)
7dd248b0be is described below
commit 7dd248b0be9ff57cd3d853d9a679833896020cb0
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Nov 10 05:05:47 2023 -0800
[Unity][Dlight] Choose perfect spatial factor in reduction rule (#16101)
---
python/tvm/dlight/gpu/reduction.py | 10 ++++-
tests/python/dlight/test_gpu_reduction.py | 68 ++++++++++++++++++++++++++++++-
2 files changed, 76 insertions(+), 2 deletions(-)
diff --git a/python/tvm/dlight/gpu/reduction.py
b/python/tvm/dlight/gpu/reduction.py
index ddc0be3861..3e2e5ee532 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -99,7 +99,7 @@ class Reduction(ScheduleRule):
if is_inner_reduction:
self._sch_inner_reduction(sch, target, block, c_factor, epilogue)
else:
- self._sch_inner_spatial(sch, target, block, c_factor, epilogue)
+ self._sch_inner_spatial(sch, target, block, block_info, c_factor,
epilogue)
return sch
def _normalize( # pylint: disable=too-many-branches
@@ -198,12 +198,20 @@ class Reduction(ScheduleRule):
sch: tir.Schedule,
_: Target,
block: tir.schedule.BlockRV,
+ block_info: BlockInfo,
unroll_spatial_factor: Optional[int],
epilogue_info: Optional[BlockInfo],
):
# pylint: disable=invalid-name
s, r, _ = sch.get_loops(block)
len_tx, len_ty = 16, 16
+ s_factor = [i.dom for i in block_info.iters if i.kind == "S"][-1]
+        # get the perfect spatial factor: the spatial factor should divide the
innermost spatial loop so
+        # that the block after rfactor can be reverse-compute-at-ed to the
original scope
+ while len_tx > 1:
+ if s_factor % len_tx == 0:
+ break
+ len_tx -= 1
_, _ = sch.split(s, factors=[None, len_tx])
_, ty = sch.split(r, factors=[None, len_ty])
# Schedule the RF block
diff --git a/tests/python/dlight/test_gpu_reduction.py
b/tests/python/dlight/test_gpu_reduction.py
index 6f4c6df25e..6198a2eb72 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -713,7 +713,6 @@ def test_reduction_inner_no_broadcasting():
def test_reduction_inner_no_broadcasting2():
-
# fmt: off
@I.ir_module
class Module:
@@ -795,5 +794,72 @@ def test_reduction_inner_no_broadcasting2():
assert_structural_equal(mod, Expected)
+def test_reduction_inner_spatial_choose_perfect_factor():
+ # fmt: off
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def main(var_A: T.handle, var_B: T.handle, matmul:
T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int64()
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1),
n), "float16")
+ B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n,
T.int64(100)), "float16")
+ # with T.block("root"):
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32),
T.int64(1), T.int64(100), n):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0,
i1, i2, i3, k])
+ T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
+ T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
+ with T.init():
+ matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
+ matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2,
v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def main(var_A: T.handle, var_B: T.handle, matmul:
T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ n = T.int64()
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1),
n), "float16")
+ B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n,
T.int64(100)), "float16")
+ # with T.block("root"):
+ matmul_rf_local = T.alloc_buffer((T.int64(16), T.int64(1),
T.int64(32), T.int64(1), T.int64(100)), "float16", scope="local")
+ for ax0_ax1_fused_0 in T.thread_binding(T.int64(320),
thread="blockIdx.x"):
+ for ax0_ax1_fused_1 in T.thread_binding(T.int64(10),
thread="threadIdx.x"):
+ for ax2_fused_1 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ with T.block("matmul_rf_init"):
+ vax2_fused_1 = T.axis.spatial(T.int64(16),
ax2_fused_1)
+ v0 = T.axis.spatial(T.int64(32), (ax0_ax1_fused_0
* T.int64(10) + ax0_ax1_fused_1) // T.int64(100))
+ v1 = T.axis.spatial(T.int64(100), (ax0_ax1_fused_0
* T.int64(10) + ax0_ax1_fused_1) % T.int64(100))
+ T.reads()
+ T.writes(matmul_rf_local[vax2_fused_1, T.int64(0),
v0, T.int64(0), v1])
+ matmul_rf_local[vax2_fused_1, T.int64(0), v0,
T.int64(0), v1] = T.float16(0)
+ for ax2_fused_0, u in T.grid((n + T.int64(15)) //
T.int64(16), 1):
+ with T.block("matmul_rf_update"):
+ vax2_fused_1 = T.axis.spatial(T.int64(16),
ax2_fused_1)
+ v0 = T.axis.spatial(T.int64(32),
(ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) // T.int64(100))
+ v1 = T.axis.spatial(T.int64(100),
(ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) % T.int64(100))
+ vax2_fused_0 = T.axis.reduce((n + T.int64(15))
// T.int64(16), ax2_fused_0)
+ T.where(ax2_fused_0 * T.int64(16) +
ax2_fused_1 < n)
+ T.reads(matmul_rf_local[vax2_fused_1,
T.int64(0), v0, T.int64(0), v1], A[T.int64(0), v0, T.int64(0), vax2_fused_0 *
T.int64(16) + vax2_fused_1], B[T.int64(0), v0, vax2_fused_0 * T.int64(16) +
vax2_fused_1, v1])
+ T.writes(matmul_rf_local[vax2_fused_1,
T.int64(0), v0, T.int64(0), v1])
+ matmul_rf_local[vax2_fused_1, T.int64(0), v0,
T.int64(0), v1] = matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]
+ A[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1] *
B[T.int64(0), v0, vax2_fused_0 * T.int64(16) + vax2_fused_1, v1]
+ for ax1_ax2_fused in T.thread_binding(T.int64(10),
thread="threadIdx.x"):
+ for ax0 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ with T.block("matmul"):
+ vax2_fused_1 = T.axis.reduce(T.int64(16), ax0)
+ v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused_0
// T.int64(10))
+ v1 = T.axis.spatial(T.int64(100), ax0_ax1_fused_0
% T.int64(10) * T.int64(10) + ax1_ax2_fused)
+ T.reads(matmul_rf_local[vax2_fused_1, T.int64(0),
v0, T.int64(0), v1])
+ T.writes(matmul[T.int64(0), v0, T.int64(0), v1])
+ with T.init():
+ matmul[T.int64(0), v0, T.int64(0), v1] =
T.float16(0)
+ matmul[T.int64(0), v0, T.int64(0), v1] =
matmul[T.int64(0), v0, T.int64(0), v1] + matmul_rf_local[vax2_fused_1,
T.int64(0), v0, T.int64(0), v1]
+ # fmt: on
+
+ with Target("nvidia/geforce-rtx-3090-ti"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module) # pylint:
disable=not-callable
+ assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()