This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1e1ff66fb3 [Unity][Dlight] Fix DecodeGeMV rule for spatial-inner with
grouping (#15340)
1e1ff66fb3 is described below
commit 1e1ff66fb3ba7072d644dd005eed9da61271e1df
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jul 17 17:39:54 2023 -0700
[Unity][Dlight] Fix DecodeGeMV rule for spatial-inner with grouping (#15340)
This PR fixes a bug in the `DecodeGEMV` dlight rule when the innermost
tensor dimension is spatial with an `unroll_factor` (for example, the
grouping used in group quantization).
Prior to this PR, the reduction loop bound to threadIdx.y was reordered
to sit outside the inner part of a split spatial loop. This prevented the
TIR LowerCrossThreadReduction pass from applying successfully, because of
a safety-guard requirement in that pass.
This PR fixes the issue by no longer moving the inner split spatial loop
after the reduction loop. The reduction loop now stays innermost, so the
pass can be applied.
Note that this change is valid because the order of thread-binding loops
does not matter.
---
python/tvm/dlight/gpu/decode_gemv.py | 3 +--
tests/python/dlight/test_gpu_decode_gemv.py | 4 ++--
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/python/tvm/dlight/gpu/decode_gemv.py b/python/tvm/dlight/gpu/decode_gemv.py
index 1aa5d68fc5..5566f3248c 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -220,8 +220,7 @@ class DecodeGEMV(ScheduleRule):
s = sch.fuse(*s)
sch.reorder(s, r)
if unroll_spatial_factor:
- s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
- sch.reorder(s, r, inner)
+ s, _ = sch.split(s, factors=[None, unroll_spatial_factor])
sch.bind(s, "threadIdx.x")
sch.bind(r, "threadIdx.y")
# Schedule epilogue
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py b/tests/python/dlight/test_gpu_decode_gemv.py
index 971f5f4d09..d037ffa3ee 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -259,8 +259,8 @@ def test_decode_gemv_4():
vk_fused_0 = T.axis.reduce(256, k_fused_0)
C_rf_local[vk_fused_1, 0, 0, v_i2] =
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] *
((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1,
v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32])
for ax1_ax2_ax3_fused_0 in T.thread_binding(16,
thread="threadIdx.x"):
- for ax0_fused in T.thread_binding(16,
thread="threadIdx.y"):
- for ax1_ax2_ax3_fused_1 in range(8):
+ for ax1_ax2_ax3_fused_1 in range(8):
+ for ax0_fused in T.thread_binding(16,
thread="threadIdx.y"):
with T.block("matmul"):
vk_fused_1 = T.axis.reduce(16, ax0_fused)
v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0
* 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)