This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new dd3bfb3424 [Unity][DLight] Fix outer_reduction for dynamic shape workloads (#15743)
dd3bfb3424 is described below

commit dd3bfb342439c221fe9d6e1626d27099f5095593
Author: Siyuan Feng <[email protected]>
AuthorDate: Sat Sep 16 13:58:02 2023 +0800

    [Unity][DLight] Fix outer_reduction for dynamic shape workloads (#15743)
    
    PR https://github.com/apache/tvm/pull/15730 introduced the outer_reduction
    schedule for Adreno GEMV. This PR fixes the vector-length selection when the
    schedule is applied to dynamic-shape workloads.
---
 python/tvm/dlight/gpu/gemv.py        |  2 +-
 tests/python/dlight/test_gpu_gemv.py | 85 ++++++++++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index 1a3cfb5e26..3544719af0 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -514,7 +514,7 @@ class GEMV(ScheduleRule):
 
         # The config is designed for Adreno
         tx_len = 64
-        vec_len = 4 if len_s > 4096 else 2
+        vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1
         inner_r = 4
 
         bx, tx, vec = sch.split(s, factors=[None, tx_len, vec_len])
diff --git a/tests/python/dlight/test_gpu_gemv.py 
b/tests/python/dlight/test_gpu_gemv.py
index e36d58be05..7f60d5db32 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -822,5 +822,90 @@ def test_outer_reduction_adreno():
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
+def test_outer_reduction_adreno_dynamic():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: 
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: 
T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        v = T.int64()
+        lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32")
+        lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16")
+        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), 
T.int64(1), v))
+        # with T.block("root"):
+        p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), v), 
"float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), v), 
"float16")
+        for i, j in T.grid(T.int64(4096), v):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv612[v_i // T.int64(8), v_j], lv613[v_i // 
T.int64(32), v_j])
+                T.writes(p_output0_intermediate_1[v_i, v_j])
+                p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv612[v_i // T.int64(8), v_j], T.Cast("uint32", v_i 
% T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613[v_i // 
T.int64(32), v_j]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), v, T.int64(4096)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv1607[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, 
v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1607[v_i0, v_i1, v_k] * 
p_output0_intermediate_1[v_k, v_i2]
+        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), v):
+            with T.block("compute"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])
+                p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", 
var_matmul_intermediate[v_i0, v_i1, v_i2])
+
+    @T.prim_func(private=True)
+    def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: 
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: 
T.handle):
+        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        v = T.int64()
+        lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32")
+        lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16")
+        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), 
T.int64(1), v))
+        # with T.block("root"):
+        var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), 
T.int64(1), v), "float16", scope="local")
+        lv1607_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), 
"float16", scope="local")
+        for u_fused in T.thread_binding(1, thread="blockIdx.y"):
+            for ax0_fused_0 in T.thread_binding((v + T.int64(63)) // 
T.int64(64), thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(T.int64(64), 
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 8, 
"pragma_unroll_explicit": 1}):
+                    for ax0_fused_2_init in T.vectorized(T.int64(1)):
+                        with T.block("matmul_init"):
+                            v0 = T.axis.spatial(v, ax0_fused_0 * T.int64(64) + 
ax0_fused_1 + ax0_fused_2_init)
+                            T.where(ax0_fused_0 * T.int64(64) + ax0_fused_1 + 
ax0_fused_2_init < v)
+                            T.reads()
+                            T.writes(var_matmul_intermediate_local[T.int64(0), 
T.int64(0), v0])
+                            var_matmul_intermediate_local[T.int64(0), 
T.int64(0), v0] = T.float16(0)
+                    for ax1_0_fused_0, ax1_0_fused_1 in T.grid(T.int64(128), 
T.int64(4)):
+                        for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                            for ax2 in T.vectorized(T.int64(8)):
+                                with T.block("lv1607_local"):
+                                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                                    v2 = T.axis.spatial(T.int64(4096), 
ax1_0_fused_0 * T.int64(32) + ax1_0_fused_1 * T.int64(8) + ax2)
+                                    T.reads(lv1607[v0, v1, v2])
+                                    T.writes(lv1607_local[v0, v1, v2])
+                                    lv1607_local[v0, v1, v2] = lv1607[v0, v1, 
v2]
+                        for ax1_1 in range(T.int64(8)):
+                            for ax0_fused_2 in T.vectorized(T.int64(1)):
+                                with T.block("matmul_update"):
+                                    v0 = T.axis.spatial(v, ax0_fused_0 * 
T.int64(64) + ax0_fused_1 + ax0_fused_2)
+                                    v1 = T.axis.reduce(T.int64(4096), 
ax1_0_fused_0 * T.int64(32) + ax1_0_fused_1 * T.int64(8) + ax1_1)
+                                    T.where(ax0_fused_0 * T.int64(64) + 
ax0_fused_1 + ax0_fused_2 < v)
+                                    
T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0], 
lv1607_local[T.int64(0), T.int64(0), v1], lv612[v1 // T.int64(8), v0], lv613[v1 
// T.int64(32), v0])
+                                    
T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
+                                    var_matmul_intermediate_local[T.int64(0), 
T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + 
lv1607_local[T.int64(0), T.int64(0), v1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(lv612[v1 // T.int64(8), v0], T.Cast("uint32", v1 % 
T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613[v1 // 
T.int64(32), v0])
+                    with T.block("compute"):
+                        v0 = T.axis.spatial(v, ax0_fused_0 * T.int64(64) + 
ax0_fused_1)
+                        T.where(ax0_fused_0 * T.int64(64) + ax0_fused_1 < v)
+                        T.reads(var_matmul_intermediate_local[T.int64(0), 
T.int64(0), v0])
+                        T.writes(p_output0_intermediate[T.int64(0), 
T.int64(0), v0])
+                        p_output0_intermediate[T.int64(0), T.int64(0), v0] = 
T.Cast("float32", var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
+    # fmt: on
+
+    mod = tvm.IRModule({"main": before})
+    with Target("opencl", host="llvm -mtriple=aarch64-linux-android"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+        tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to