This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 209971a62e [DLIGHT][GPU] Improved gemv outer fallback schedule (#16973)
209971a62e is described below

commit 209971a62edf4a6ea6c628ef8399e45e926e727c
Author: krishnaraj36 <[email protected]>
AuthorDate: Tue May 21 14:24:53 2024 +0530

    [DLIGHT][GPU] Improved gemv outer fallback schedule (#16973)
    
    * [DLIGHT][GPU] Improved gemv outer fallback schedule
    
    Improved the gemv outer fallback schedules. This improves a
    few gemv kernels by about 20%.
    
    * Fix lint error
    
    * Fix the gemv schedule params for dynamic vocab_size kernel
---
 python/tvm/dlight/gpu/gemv.py        |  39 ++++++++----
 tests/python/dlight/test_gpu_gemv.py | 113 +++++++++++++++++++----------------
 2 files changed, 91 insertions(+), 61 deletions(-)
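
The gemv.py change restructures how the fallback thread-tile sizes (TS, TR)
are chosen when the spatial extent len_S is dynamic. Previously each target
branch set its tuned values and a single trailing check
("if not isinstance(len_S, int): TS, TR = 1, 64") then clobbered all of them;
the patch removes that check and gives each target branch its own else, so a
backend can keep a tuned dynamic-shape fallback (on this reading, the Android
OpenCL branch now uses TS, TR = 2, 32 for dynamic shapes as well). A minimal
sketch of the new control flow; pick_thread_tiles and the "android-opencl"
label are hypothetical, and the static-shape cutoffs sit outside the visible
hunks:

    def pick_thread_tiles(kind: str, len_s) -> "tuple[int, int]":
        """Sketch: len_s is a Python int for static shapes, symbolic otherwise."""
        if kind == "vulkan":
            if isinstance(len_s, int):
                # the real rule chooses between (4, 32) and (16, 32) via a
                # static-shape heuristic not shown in this hunk's context
                return (4, 32)
            return (1, 64)  # fallback now lives inside the vulkan branch
        if kind == "android-opencl":  # hypothetical label: opencl + android host
            return (2, 32)  # retuned in this patch; applies to dynamic shapes too
        return (1, 64)  # conservative default, matching the added else branches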

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index cbef6235c0..da6a4ef834 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -463,6 +463,8 @@ class GEMV(GPUScheduleRule):
                     TS, TR = 4, 64
                 else:
                     TS, TR = 16, 32
+            else:
+                TS, TR = 1, 64
         elif target.kind.name == "metal":
             # Note that the following tile size is tuned on M2 Ultra for 7B
             TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
@@ -476,6 +478,8 @@ class GEMV(GPUScheduleRule):
                     TS, TR = 4, 16
                 else:
                     TS, TR = 2, 64
+            else:
+                TS, TR = 1, 64
         elif target.kind.name == "rocm":
             VEC_C = 4
             # TODO: set LOAD_V_SHARED = False for now
@@ -489,13 +493,15 @@ class GEMV(GPUScheduleRule):
                     TS, TR = 1, 128
                 else:
                     TS, TR = 8, 64
+            else:
+                TS, TR = 1, 64
         elif target.kind.name == "opencl" and "android" in str(target.host):
             TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
             VEC_C = 8
             LOAD_V_SHARED = False
             LOAD_V_VEC = -1
             UNROLL = 8
-            TS, TR = 2, 64
+            TS, TR = 2, 32
         elif target.kind.name == "vulkan":
             VEC_C = 4
             LOAD_V_SHARED = True
@@ -506,6 +512,8 @@ class GEMV(GPUScheduleRule):
                     TS, TR = 4, 32
                 else:
                     TS, TR = 16, 32
+            else:
+                TS, TR = 1, 64
         elif target.kind.name == "opencl" and "mali" in str(target.attrs):
             VEC_C = 8
             LOAD_V_SHARED = False
@@ -519,9 +527,6 @@ class GEMV(GPUScheduleRule):
             UNROLL = 64
             TS, TR = 1, 64
 
-        if not isinstance(len_S, int):
-            TS, TR = 1, 64
-
         while TS * TR > target.max_num_threads:
             if TS > 1:
                 TS //= 2
@@ -709,7 +714,11 @@ class GEMV(GPUScheduleRule):
         if not isinstance(len_r, int):
             return None
 
-        if isinstance(len_s, int) and len_s > 32000:
+        if not isinstance(len_s, int):
+            TS, TR = 256, 1
+            LOAD_V_SHARED = True
+
+        if isinstance(len_s, int) and len_s > 96000:
             return None
 
         _, TILE_R = (
@@ -754,7 +763,8 @@ class GEMV(GPUScheduleRule):
         len_s = get_extent(sch, s)
 
         # The config is designed for Adreno
-        tx_len = 64
+        LOAD_V_SHARED = 1
+        tx_len = 128
         vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1
         inner_r = 4
 
@@ -768,16 +778,23 @@ class GEMV(GPUScheduleRule):
         sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=8)
         sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)
 
-        cache_v = sch.cache_read(block, vector_input_buffers[0], "local")
-        sch.compute_at(cache_v, r1, preserve_unit_loops=True)
-        sch.vectorize(sch.get_loops(cache_v)[-1])
+        if LOAD_V_SHARED:
+            V_shared = sch.cache_read(block, vector_input_buffers[0], storage_scope="shared")
+            sch.compute_at(V_shared, bx, preserve_unit_loops=True)
+            l = sch.get_loops(block=V_shared)[-1]
+            _, tx, vec_r = sch.split(l, factors=[None, tx_len, 8], preserve_unit_iters=True)
+            sch.bind(tx, "threadIdx.x")
+            sch.vectorize(vec_r)
 
         sch.vectorize(vec)
 
         # Schedule epilogue
         if epilogue_info is not None:
-            sch.reverse_compute_at(epilogue_info.block_rv, tx)
-
+            sch.reverse_compute_at(epilogue_info.block_rv, bx, preserve_unit_loops=True)
+            ts_tile_s = sch.get_loops(epilogue_info.block_rv)[-1]
+            ts, vec = sch.split(ts_tile_s, factors=[tx_len, vec_len], preserve_unit_iters=True)
+            sch.bind(ts, "threadIdx.x")
+            sch.vectorize(vec)
             sch.set_scope(block, 0, "local")
 
         sch.decompose_reduction(block, r0)
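
The sch_outer_reduction changes above widen the thread tile (tx_len 64 -> 128)
and stage the vector operand in shared memory instead of a per-thread local
cache, loading it cooperatively: cache_read to "shared", compute_at under the
block loop, then split/bind/vectorize the copy loop. A self-contained sketch
of that primitive sequence on a toy gemv (names, shapes, and the read-buffer
index are from this toy, not from gemv.py):

    import tvm
    from tvm import te

    n, k = 256, 4096
    vec = te.placeholder((k,), "float16", name="vec")  # the vector operand
    wgt = te.placeholder((n, k), "float16", name="wgt")
    r = te.reduce_axis((0, k), name="r")
    out = te.compute((n,), lambda i: te.sum(vec[r] * wgt[i, r], axis=r), name="out")

    sch = tvm.tir.Schedule(te.create_prim_func([vec, wgt, out]))
    blk = sch.get_block("out")
    s, _ = sch.get_loops(blk)
    bx, tx = sch.split(s, factors=[None, 128])
    sch.bind(bx, "blockIdx.x")
    sch.bind(tx, "threadIdx.x")

    # Stage the vector operand (read buffer 0 of this toy block) in shared
    # memory at the block level, then bind the copy loop across the same 128
    # threads and vectorize the innermost 8 halves, as the V_shared path does.
    v_shared = sch.cache_read(blk, 0, storage_scope="shared")
    sch.compute_at(v_shared, bx, preserve_unit_loops=True)
    copy = sch.get_loops(v_shared)[-1]
    _, ctx, cvec = sch.split(copy, factors=[None, 128, 8], preserve_unit_iters=True)
    sch.bind(ctx, "threadIdx.x")
    sch.vectorize(cvec)
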
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index 4aae617654..0f7b6f45ae 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -1106,82 +1106,95 @@ def test_outer_reduction_adreno_dynamic():
         p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v))
         # with T.block("root"):
         var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16", scope="local")
-        var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), v), "float16", scope="local")
-        var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(1), v), "float16", scope="local")
+        var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(8), T.int64(1), T.int64(1), v), "float16", scope="local")
+        var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), v), "float16", scope="local")
         lv613_local = T.alloc_buffer((T.int64(128), v), "float16", scope="local")
         lv612_local = T.alloc_buffer((T.int64(512), v), "uint32", scope="local")
-        for u_fused_ax0_fused_fused_0 in T.thread_binding((v + T.int64(63)) // T.int64(64), thread="blockIdx.x"):
-            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
-                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+        lv1607_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared")
+        for u_fused_ax0_fused_fused_0 in T.thread_binding((v + T.int64(255)) // T.int64(256), thread="blockIdx.x"):
+            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(1), thread="threadIdx.y"):
                     for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(T.int64(8)):
                         with T.block("matmul_rf_init"):
-                            vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init)
-                            v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1)
-                            T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
+                            vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(8), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init)
+                            v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1)
+                            T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
                             T.reads()
                             T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
                             var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
-                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
-                    for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(T.int64(32), T.int64(1)):
-                        for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
-                            with T.block("lv613_local"):
-                                v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0)
-                                v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1)
-                                T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
-                                T.reads(lv613[v0, v1])
-                                T.writes(lv613_local[v0, v1])
-                                lv613_local[v0, v1] = lv613[v0, v1]
-                        for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)):
+                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
+                    for ax1_0_fused_ax1_1_fused_0 in range(T.int64(128)):
+                        for ax0, ax1, ax2_0, ax2_1 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
+                            for ax2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                                for ax2_3 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
+                                    for ax2_4 in T.vectorized(T.int64(4)):
+                                        with T.block("lv1607_shared"):
+                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                                            v2 = T.axis.spatial(T.int64(4096), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + (ax2_0 * T.int64(1024) + ax2_1 * T.int64(1024) + ax2_2 * T.int64(4) + ax2_3 * T.int64(4) + ax2_4))
+                                            T.where(((ax2_0 + ax2_1) * T.int64(256) + ax2_2 + ax2_3) * T.int64(4) + ax2_4 < T.int64(32))
+                                            T.reads(lv1607[v0, v1, v2])
+                                            T.writes(lv1607_shared[v0, v1, v2])
+                                            lv1607_shared[v0, v1, v2] = lv1607[v0, v1, v2]
+                        for ax1_0_fused_ax1_1_fused_1 in range(T.int64(1)):
                             for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
-                                with T.block("lv612_local"):
-                                    v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(16) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0)
-                                    v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1)
-                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
-                                    T.reads(lv612[v0, v1])
-                                    T.writes(lv612_local[v0, v1])
-                                    lv612_local[v0, v1] = lv612[v0, v1]
-                            for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)):
-                                with T.block("matmul_rf_update"):
-                                    vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1)
-                                    v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1)
-                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3])
-                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
-                                    T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0], lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)], lv [...]
-                                    T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
-                                    var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + va [...]
-            for ax2 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
-                for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+                                with T.block("lv613_local"):
+                                    v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 + ax0)
+                                    v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 + ax1)
+                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+                                    T.reads(lv613[v0, v1])
+                                    T.writes(lv613_local[v0, v1])
+                                    lv613_local[v0, v1] = lv613[v0, v1]
+                            for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)):
+                                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                                    with T.block("lv612_local"):
+                                        v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0)
+                                        v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 + ax1)
+                                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+                                        T.reads(lv612[v0, v1])
+                                        T.writes(lv612_local[v0, v1])
+                                        lv612_local[v0, v1] = lv612[v0, v1]
+                                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)):
+                                    with T.block("matmul_rf_update"):
+                                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(8), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1)
+                                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1)
+                                        vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3])
+                                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+                                        T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0], lv1607_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused], lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + [...]
+                                        T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
+                                        var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_a [...]
+            for ax2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                for ax0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
                     with T.block("matmul_rf_init"):
-                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(4), ax0)
-                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2)
-                        T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v)
+                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(1), ax0)
+                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax2)
+                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax2 < v)
                         T.reads()
                         T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
                         var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0)
                     for ax1 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}):
                         with T.block("matmul_rf_update"):
                             vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1])
-                            v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2)
-                            T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v)
+                            v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax2)
+                            T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax2 < v)
                             T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0])
                             T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
                             var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0]
-            for ax1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
-                for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+            for ax1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                for ax0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
                     with T.block("matmul"):
-                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(4), ax0)
-                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax1)
-                        T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax1 < v)
+                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(1), ax0)
+                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax1)
+                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax1 < v)
                         T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
                         T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
                         with T.init():
                             var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0)
                         var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]
-            for ax0_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
+            for ax0_fused_0 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                 for ax0_fused_1 in range(T.int64(1)):
                     with T.block("compute"):
-                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax0_fused_0 + ax0_fused_1)
-                        T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + (ax0_fused_0 + ax0_fused_1) < v)
+                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax0_fused_0 + ax0_fused_1)
+                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + (ax0_fused_0 + ax0_fused_1) < v)
                         T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
                         T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0])
                         p_output0_intermediate[T.int64(0), T.int64(0), v0] = T.Cast("float32", var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
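
For context, dlight rules like this one are applied per target as a module
pass; a minimal end-to-end sketch (assuming a TVM build with the dlight
package, and using "opencl -device=adreno" as an illustrative Adreno target
string) of driving the GEMV rule over a toy gemv module:

    import tvm
    from tvm import te
    from tvm import dlight as dl

    # Toy fp16 gemv IRModule; shapes are illustrative only.
    k = 4096
    x = te.placeholder((1, k), "float16", name="x")
    w = te.placeholder((k, k), "float16", name="w")
    r = te.reduce_axis((0, k), name="r")
    y = te.compute((1, k), lambda i, j: te.sum(x[i, r] * w[r, j], axis=r), name="y")
    mod = tvm.IRModule({"main": te.create_prim_func([x, w, y])})

    target = tvm.target.Target("opencl -device=adreno")
    with target:
        # ApplyDefaultSchedule tries each given rule in order; GEMV picks the
        # tile sizes discussed in this patch from target kind and loop extents.
        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)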
