This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1e1ff66fb3 [Unity][Dlight] Fix DecodeGeMV rule for spatial-inner with grouping (#15340)
1e1ff66fb3 is described below

commit 1e1ff66fb3ba7072d644dd005eed9da61271e1df
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jul 17 17:39:54 2023 -0700

    [Unity][Dlight] Fix DecodeGeMV rule for spatial-inner with grouping (#15340)
    
    This PR fixes a bug in the DecodeGeMV dlight rule that occurs when the
    innermost tensor dimension is spatial and is split by
    `unroll_spatial_factor` (for example, the grouping used in group
    quantization).
    
    Prior to this PR, the reduction loop bound to threadIdx.y was
    reordered to reside outside the inner part of the split spatial loop.
    This prevented the TIR LowerCrossThreadReduction pass from being
    applied, because the resulting loop order violates one of the pass's
    safety-guard requirements.
    
    This PR fixes the issue by no longer reordering the inner part of the
    split spatial loop after the reduction loop, so that the pass can be
    applied. This is safe because the relative order of thread-binding
    loops does not matter.
---
 python/tvm/dlight/gpu/decode_gemv.py        | 3 +--
 tests/python/dlight/test_gpu_decode_gemv.py | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/tvm/dlight/gpu/decode_gemv.py b/python/tvm/dlight/gpu/decode_gemv.py
index 1aa5d68fc5..5566f3248c 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -220,8 +220,7 @@ class DecodeGEMV(ScheduleRule):
         s = sch.fuse(*s)
         sch.reorder(s, r)
         if unroll_spatial_factor:
-            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
-            sch.reorder(s, r, inner)
+            s, _ = sch.split(s, factors=[None, unroll_spatial_factor])
         sch.bind(s, "threadIdx.x")
         sch.bind(r, "threadIdx.y")
         # Schedule epilogue
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py b/tests/python/dlight/test_gpu_decode_gemv.py
index 971f5f4d09..d037ffa3ee 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -259,8 +259,8 @@ def test_decode_gemv_4():
                                 vk_fused_0 = T.axis.reduce(256, k_fused_0)
                                 C_rf_local[vk_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1, 
v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32])
                 for ax1_ax2_ax3_fused_0 in T.thread_binding(16, 
thread="threadIdx.x"):
-                    for ax0_fused in T.thread_binding(16, 
thread="threadIdx.y"):
-                        for ax1_ax2_ax3_fused_1 in range(8):
+                    for ax1_ax2_ax3_fused_1 in range(8):
+                        for ax0_fused in T.thread_binding(16, 
thread="threadIdx.y"):
                             with T.block("matmul"):
                                 vk_fused_1 = T.axis.reduce(16, ax0_fused)
                                 v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)

Reply via email to