This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b01de08715 [DLight] Fix a corner case for reduction rule (#16848)
b01de08715 is described below
commit b01de087157e448c3454766393a057d9565e7d73
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Apr 5 17:12:54 2024 +0800
[DLight] Fix a corner case for reduction rule (#16848)
The current rule will fail when the output shape is only one element,
because of the missing `preserve_unit_loops`. This PR fixes it and adds a
test case.
---
python/tvm/dlight/gpu/reduction.py | 2 +-
tests/python/dlight/test_gpu_reduction.py | 93 ++++++++++++++++++++++++++-----
2 files changed, 79 insertions(+), 16 deletions(-)
diff --git a/python/tvm/dlight/gpu/reduction.py
b/python/tvm/dlight/gpu/reduction.py
index 651e09dc52..4cc142ab16 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -217,7 +217,7 @@ class Reduction(GPUScheduleRule):
# Schedule epilogue
if epilogue_info is not None:
epilogue = epilogue_info.block_rv
- sch.reverse_compute_at(epilogue, bx)
+ sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
if is_broadcast_epilogue(sch, block, epilogue):
sch.set_scope(block, 0, "shared")
_, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name
diff --git a/tests/python/dlight/test_gpu_reduction.py
b/tests/python/dlight/test_gpu_reduction.py
index def124a9b2..1ce57eb53d 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -377,11 +377,12 @@ def test_decode_gemv_sigmoid():
with T.init():
C_local[0, 0, v0] = T.float16(0)
C_local[0, 0, v0] = C_local[0, 0, v0] +
C_rf_local[vax1_0_fused_1, 0, 0, v0]
- with T.block("sigmoid"):
- v0 = T.axis.spatial(4096, ax0_fused)
- T.reads(C_local[0, 0, v0])
- T.writes(D[0, 0, v0])
- D[0, 0, v0] = T.sigmoid(C_local[0, 0, v0])
+ for ax0 in range(1):
+ with T.block("sigmoid"):
+ v0 = T.axis.spatial(4096, ax0_fused + ax0)
+ T.reads(C_local[0, 0, v0])
+ T.writes(D[0, 0, v0])
+ D[0, 0, v0] = T.sigmoid(C_local[0, 0, v0])
# fmt: on
@@ -465,11 +466,12 @@ def test_decode_gemv_1_fp32():
with T.init():
C_fp32_local[0, 0, v0] = T.float32(0)
C_fp32_local[0, 0, v0] = C_fp32_local[0, 0,
v0] + C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]
- with T.block("cast"):
- v0 = T.axis.spatial(4096, ax0_fused)
- T.reads(C_fp32_local[0, 0, v0])
- T.writes(C[0, 0, v0])
- C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0, v0])
+ for ax0 in range(1):
+ with T.block("cast"):
+ v0 = T.axis.spatial(4096, ax0_fused + ax0)
+ T.reads(C_fp32_local[0, 0, v0])
+ T.writes(C[0, 0, v0])
+ C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0,
v0])
# fmt: on
@@ -760,11 +762,12 @@ def test_reduction_inner_no_broadcasting():
with T.init():
temp_local_local[v0] = T.float32(0)
temp_local_local[v0] = temp_local_local[v0] +
temp_local_rf_local[vax1_fused_1, v0]
- with T.block("add"):
- v0 = T.axis.spatial(256, ax0_fused)
- T.reads(temp_local_local[v0])
- T.writes(B[v0])
- B[v0] = temp_local_local[v0] + T.float32(1)
+ for ax0 in range(1):
+ with T.block("add"):
+ v0 = T.axis.spatial(256, ax0_fused + ax0)
+ T.reads(temp_local_local[v0])
+ T.writes(B[v0])
+ B[v0] = temp_local_local[v0] + T.float32(1)
# fmt: on
target = Target("nvidia/geforce-rtx-3090-ti")
@@ -1089,5 +1092,65 @@ def test_gemv_dyn_shape_epilogue():
assert_structural_equal(mod, Expected)
+def test_gemv_output_one_element():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight:
T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1),
T.int64(1)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1)),
"float16")
+ for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ with T.init():
+ NT_matmul_intermediate[v_i0, v_i1] = T.float16(0)
+ NT_matmul_intermediate[v_i0, v_i1] =
NT_matmul_intermediate[v_i0, v_i1] + A[v_i0, v_k] * weight[v_i1, v_k]
+ for i0, i1 in T.grid(T.int64(1), T.int64(1)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ out[v_i0, v_i1] = T.sigmoid(NT_matmul_intermediate[v_i0,
v_i1])
+
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight:
T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1),
T.int64(1)), "float16")):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ NT_matmul_intermediate_shared = T.alloc_buffer((T.int64(1),
T.int64(1)), "float16", scope="shared")
+ NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(1024),
T.int64(1), T.int64(1)), "float16", scope="local")
+ for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+ for ax1_fused_1 in T.thread_binding(T.int64(1024),
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
+ with T.block("NT_matmul_rf_init"):
+ vax1_fused_1 = T.axis.spatial(T.int64(1024),
ax1_fused_1)
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ NT_matmul_intermediate_rf_local[vax1_fused_1,
T.int64(0), T.int64(0)] = T.float16(0)
+ for ax1_fused_0, u in T.grid(T.int64(2), 1):
+ with T.block("NT_matmul_rf_update"):
+ vax1_fused_1 = T.axis.spatial(T.int64(1024),
ax1_fused_1)
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ vax1_fused_0 = T.axis.reduce(T.int64(2),
ax1_fused_0)
+ NT_matmul_intermediate_rf_local[vax1_fused_1,
T.int64(0), T.int64(0)] = NT_matmul_intermediate_rf_local[vax1_fused_1,
T.int64(0), T.int64(0)] + A[T.int64(0), vax1_fused_0 * T.int64(1024) +
vax1_fused_1] * weight[T.int64(0), vax1_fused_0 * T.int64(1024) + vax1_fused_1]
+ for ax1_fused in range(T.int64(1)):
+ for ax0 in T.thread_binding(T.int64(1024),
thread="threadIdx.x"):
+ with T.block("NT_matmul"):
+ vax1_fused_1 = T.axis.reduce(T.int64(1024), ax0)
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ with T.init():
+ NT_matmul_intermediate_shared[T.int64(0),
T.int64(0)] = T.float16(0)
+ NT_matmul_intermediate_shared[T.int64(0),
T.int64(0)] = NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] +
NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)]
+ for ax0_fused_0 in range(T.int64(1)):
+ for ax0_fused_1 in T.thread_binding(T.int64(1024),
thread="threadIdx.x"):
+ with T.block("compute"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1
< T.int64(1))
+ out[T.int64(0), T.int64(0)] =
T.sigmoid(NT_matmul_intermediate_shared[T.int64(0), T.int64(0)])
+ # fmt: on
+
+ with Target("nvidia/geforce-rtx-3090-ti"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before) # pylint:
disable=not-callable
+ assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()