This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b00ed76318 [Unity][Dlight] Minor performance improvement for gemm and 
gemv (#15278)
b00ed76318 is described below

commit b00ed763183e7c4ad2f9d97ab1d2b9291b838bb7
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Jul 10 22:00:05 2023 +0800

    [Unity][Dlight] Minor performance improvement for gemm and gemv (#15278)
    
    This PR improves the performance of gemm and gemv in dlight:
    
    - gemm: use vectorized load/store and change bindings.
    - gemv: apply unroll
---
 python/tvm/dlight/gpu/decode_gemv.py        |   2 +
 python/tvm/dlight/gpu/matmul.py             |  11 ++-
 tests/python/dlight/test_gpu_decode_gemv.py |  10 +--
 tests/python/dlight/test_gpu_matmul.py      | 112 ++++++++++++++--------------
 4 files changed, 70 insertions(+), 65 deletions(-)

diff --git a/python/tvm/dlight/gpu/decode_gemv.py 
b/python/tvm/dlight/gpu/decode_gemv.py
index d0d37e8476..6c7e31181b 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -206,6 +206,8 @@ class DecodeGEMV(ScheduleRule):
         sch.reorder(bx, tx, r)
         sch.bind(bx, "blockIdx.x")
         sch.bind(tx, "threadIdx.x")
+        sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=256)
+        sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)
         sch.set_scope(rf, 0, "local")
         sch.decompose_reduction(rf, r)
         # Schedule the write back block
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 86d685e53c..be5e4b02d7 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -297,7 +297,7 @@ class Matmul(ScheduleRule):
         block_size_y = 16
         vthread_x = 1
         vthread_y = 1
-        micro_size_x = 2
+        micro_size_x = 4
         micro_size_y = 4
         micro_size_k = 16
         vector_size = 2
@@ -340,20 +340,23 @@ class Matmul(ScheduleRule):
 
         l2g = sch.cache_write(main_block, 0, "local")
         sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)
+        if micro_size_y % vector_size == 0:
+            _, v = sch.split(sch.get_loops(l2g)[-1], [None, vector_size])
+            sch.vectorize(v)
 
         def _cooperative_fetch(index, vec_len):
             block = sch.cache_read(main_block, index, "shared")
             num_loops = len(sch.get_loops(block))
             sch.compute_at(block, ko, preserve_unit_loops=True)
             loops = sch.get_loops(block)[-num_loops:]
-            _, ty, tx, vec = sch.split(
+            ty, tx, _, vec = sch.split(
                 sch.fuse(*loops),
-                factors=[None, block_size_y, block_size_x, vec_len],
+                factors=[block_size_y, block_size_x, None, vec_len],
             )
             sch.vectorize(vec)
             sch.bind(ty, "threadIdx.y")
             sch.bind(tx, "threadIdx.x")
-            sch.storage_align(block, 0, axis=1, factor=32, offset=vec_len)
+            sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len)
             return block
 
         a_g2s = _cooperative_fetch(0, vec_len=vector_size)
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py 
b/tests/python/dlight/test_gpu_decode_gemv.py
index bd84aeb096..7b19e6b7f8 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -56,7 +56,7 @@ def test_decode_gemv_1():
             # with T.block("root"):
             C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", 
scope="local")
             for i2_i0_i1_fused in T.thread_binding(4096, thread="blockIdx.x"):
-                for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                     with T.block("matmul_rf_init"):
                         vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
                         v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
@@ -179,7 +179,7 @@ def test_decode_gemv_3():
             # with T.block("root"):
             C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", 
scope="local")
             for i2_0_i0_i1_fused in T.thread_binding(512, thread="blockIdx.x"):
-                for k_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                for k_fused_1 in T.thread_binding(256, thread="threadIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                     for i2_1_init in range(8):
                         with T.block("matmul_rf_init"):
                             vk_fused_1 = T.axis.spatial(256, k_fused_1)
@@ -315,7 +315,7 @@ def test_decode_gemv_sigmoid():
             C_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
             C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", 
scope="local")
             for i2_i0_i1_fused in T.thread_binding(4096, thread="blockIdx.x"):
-                for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                     with T.block("matmul_rf_init"):
                         vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
                         v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
@@ -387,7 +387,7 @@ def test_decode_gemv_1_fp32():
             C_fp32_local = T.alloc_buffer((1, 1, 4096), scope="local")
             C_fp32_rf_local = T.alloc_buffer((256, 1, 1, 4096), scope="local")
             for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"):
-                for ax1_0_fused_1 in T.thread_binding(256, 
thread="threadIdx.x"):
+                for ax1_0_fused_1 in T.thread_binding(256, 
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
                     with T.block("matmul_rf_init"):
                         vax1_0_fused_1, v0 = T.axis.remap("SS", 
[ax1_0_fused_1, ax0_fused])
                         T.reads()
@@ -450,7 +450,7 @@ def test_reduction_no_spatial():
             Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared")
             Ared_temp_rf_local = T.alloc_buffer((256, 1, 1), scope="local")
             for ax0_fused in T.thread_binding(T.int64(1), 
thread="blockIdx.x"): # pylint: disable=unused-variable
-                for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                     with T.block("Ared_temp_rf_init"):
                         vax1_fused_1 = T.axis.spatial(256, ax1_fused_1)
                         v0 = T.axis.spatial(T.int64(1), T.int64(0))
diff --git a/tests/python/dlight/test_gpu_matmul.py 
b/tests/python/dlight/test_gpu_matmul.py
index 38ecbfee94..f3d9a7089d 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -19,8 +19,6 @@ import pytest
 
 import tvm.testing
 from tvm import dlight as dl
-from tvm.ir import assert_structural_equal
-from tvm.script import ir as I
 from tvm.script import tir as T
 from tvm.target import Target
 
@@ -56,67 +54,68 @@ class TestMatmul(BaseBeforeAfter):
         inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
         matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
         # with T.block("root"):
-        matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + 
T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local")
-        inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + 
T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="shared")
+        matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
+        inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
         inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), 
T.int64(4096)), scope="shared")
         for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
-            for ax1_0 in T.thread_binding((m + T.int64(15)) // T.int64(16), 
thread="blockIdx.x"):
+            for ax1_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), 
thread="blockIdx.x"):
                 for ax2_0 in T.thread_binding(T.int64(64), 
thread="blockIdx.y"):
                     for ax2_1 in T.thread_binding(T.int64(1), 
thread="vthread.y"):
                         for ax1_1 in T.thread_binding(T.int64(1), 
thread="vthread.x"):
                             for ax2_2 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
                                 for ax1_2 in T.thread_binding(T.int64(8), 
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
-                                    for ax2_3_init, ax1_3_init in 
T.grid(T.int64(4), T.int64(2)):
+                                    for ax2_3_init, ax1_3_init in 
T.grid(T.int64(4), T.int64(4)):
                                         with T.block("matmul_init"):
                                             v0 = T.axis.spatial(T.int64(1), 
ax0)
-                                            v1 = T.axis.spatial((m + 
T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_1 * 
T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init)
+                                            v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * 
T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                             v2 = T.axis.spatial(T.int64(4096), 
ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
                                             T.reads()
                                             
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
                                             
matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
                                     for ax3_0 in range(T.int64(256)):
-                                        for ax0_ax1_ax2_fused_0 in 
range(T.int64(1)):
-                                            for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                                for ax0_ax1_ax2_fused_2 in 
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                        for ax0_ax1_ax2_fused_0 in 
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                            for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                                for ax0_ax1_ax2_fused_2 in 
range(T.int64(2)):
                                                     for ax0_ax1_ax2_fused_3 in 
T.vectorized(T.int64(2)):
                                                         with 
T.block("inp0_reindex_pad_shared"):
                                                             v0 = 
T.axis.spatial(T.int64(1), T.int64(0))
-                                                            v1 = 
T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * 
T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * 
T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // 
T.int64(16))
-                                                            v2 = 
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                            v1 = 
T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * 
T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * 
T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // 
T.int64(16))
+                                                            v2 = 
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
                                                             T.reads(inp0[v0, 
v1, v2])
                                                             
T.writes(inp0_reindex_pad_shared[v0, v1, v2])
-                                                            
T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+                                                            
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                             
inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], 
T.float32(0))
-                                        for ax0_ax1_ax2_fused_0 in 
range(T.int64(4)):
-                                            for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                                for ax0_ax1_ax2_fused_2 in 
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                        for ax0_ax1_ax2_fused_0 in 
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                            for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                                for ax0_ax1_ax2_fused_2 in 
range(T.int64(4)):
                                                     for ax0_ax1_ax2_fused_3 in 
T.vectorized(T.int64(2)):
                                                         with 
T.block("inp1_reindex_shared"):
                                                             v0 = 
T.axis.spatial(T.int64(1), T.int64(0))
-                                                            v1 = 
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * 
T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
-                                                            v2 = 
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                            v1 = 
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * 
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                            v2 = 
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
                                                             T.reads(inp1[v2, 
v1])
                                                             
T.writes(inp1_reindex_shared[v0, v1, v2])
-                                                            
T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+                                                            
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                             
inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
-                                        for ax3_1, ax2_3, ax1_3 in 
T.grid(T.int64(16), T.int64(4), T.int64(2)):
+                                        for ax3_1, ax2_3, ax1_3 in 
T.grid(T.int64(16), T.int64(4), T.int64(4)):
                                             with T.block("matmul_update"):
                                                 v0 = 
T.axis.spatial(T.int64(1), ax0)
-                                                v1 = T.axis.spatial((m + 
T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_1 * 
T.int64(16) + ax1_2 * T.int64(2) + ax1_3)
+                                                v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * 
T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                                 v2 = 
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 
* T.int64(4) + ax2_3)
                                                 v3 = 
T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
                                                 
T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], 
inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), 
v2, v3])
                                                 
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
                                                 
matmul_reindex_pad_local[T.int64(0), v1, v2] = 
matmul_reindex_pad_local[T.int64(0), v1, v2] + 
inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), 
v2, v3]
-                                    for ax0_1, ax1, ax2 in T.grid(T.int64(1), 
T.int64(2), T.int64(4)):
-                                        with 
T.block("matmul_reindex_pad_local"):
-                                            v0 = T.axis.spatial(T.int64(1), 
ax0_1)
-                                            v1 = T.axis.spatial((m + 
T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_2 * 
T.int64(2) + ax1)
-                                            v2 = T.axis.spatial(T.int64(4096), 
ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2)
-                                            
T.reads(matmul_reindex_pad_local[v0, v1, v2])
-                                            T.writes(matmul[T.int64(0), v1, 
v2])
-                                            if v1 < m:
-                                                matmul[T.int64(0), v1, v2] = 
matmul_reindex_pad_local[v0, v1, v2]
+                                    for ax0_1, ax1, ax2_0_1 in 
T.grid(T.int64(1), T.int64(4), T.int64(2)):
+                                        for ax2_1_1 in 
T.vectorized(T.int64(2)):
+                                            with 
T.block("matmul_reindex_pad_local"):
+                                                v0 = 
T.axis.spatial(T.int64(1), ax0_1)
+                                                v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * 
T.int64(4) + ax1)
+                                                v2 = 
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + 
ax2_0_1 * T.int64(2) + ax2_1_1)
+                                                
T.reads(matmul_reindex_pad_local[v0, v1, v2])
+                                                T.writes(matmul[T.int64(0), 
v1, v2])
+                                                if v1 < m:
+                                                    matmul[T.int64(0), v1, v2] 
= matmul_reindex_pad_local[v0, v1, v2]
     # fmt: on
 
 
@@ -147,70 +146,71 @@ class TestFusedMatmul(BaseBeforeAfter):
                 T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, 
v_ax1, v_ax2])
                 T.writes(Out[v_ax0, v_ax1, v_ax2])
                 Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + 
var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
-
     @T.prim_func
     def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: 
T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), 
T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), 
T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), 
T.int64(4096)), "float32")):
         T.func_attr({"tir.is_scheduled": 1})
+        # with T.block("root"):
         var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1), 
T.int64(32), T.int64(4096)), scope="local")
         A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32), 
T.int64(4096)), scope="shared")
         var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), 
T.int64(4096), T.int64(4096)), scope="shared")
         for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
-            for ax1_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
+            for ax1_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                 for ax2_0 in T.thread_binding(T.int64(64), 
thread="blockIdx.y"):
                     for ax2_1 in T.thread_binding(T.int64(1), 
thread="vthread.y"):
                         for ax1_1 in T.thread_binding(T.int64(1), 
thread="vthread.x"):
                             for ax2_2 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
                                 for ax1_2 in T.thread_binding(T.int64(8), 
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
-                                    for ax2_3_init, ax1_3_init in 
T.grid(T.int64(4), T.int64(2)):
+                                    for ax2_3_init, ax1_3_init in 
T.grid(T.int64(4), T.int64(4)):
                                         with T.block("matmul_init"):
                                             v0 = T.axis.spatial(T.int64(1), 
ax0)
-                                            v1 = T.axis.spatial(T.int64(32), 
ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init)
+                                            v1 = T.axis.spatial(T.int64(32), 
ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                             v2 = T.axis.spatial(T.int64(4096), 
ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
                                             T.reads()
                                             
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
                                             
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
                                     for ax3_0 in range(T.int64(256)):
-                                        for ax0_ax1_ax2_fused_0 in 
range(T.int64(1)):
-                                            for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                                for ax0_ax1_ax2_fused_2 in 
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                        for ax0_ax1_ax2_fused_0 in 
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                            for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                                for ax0_ax1_ax2_fused_2 in 
range(T.int64(2)):
                                                     for ax0_ax1_ax2_fused_3 in 
T.vectorized(T.int64(2)):
                                                         with 
T.block("A_reindex_shared"):
                                                             v0 = 
T.axis.spatial(T.int64(1), T.int64(0))
-                                                            v1 = 
T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
-                                                            v2 = 
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                            v1 = 
T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) + 
ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + 
ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                            v2 = 
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
                                                             T.reads(A[v0, v1, 
v2])
                                                             
T.writes(A_reindex_shared[v0, v1, v2])
-                                                            
T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+                                                            
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                             
A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
-                                        for ax0_ax1_ax2_fused_0 in 
range(T.int64(4)):
-                                            for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(16), thread="threadIdx.y"):
-                                                for ax0_ax1_ax2_fused_2 in 
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                        for ax0_ax1_ax2_fused_0 in 
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                            for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                                for ax0_ax1_ax2_fused_2 in 
range(T.int64(4)):
                                                     for ax0_ax1_ax2_fused_3 in 
T.vectorized(T.int64(2)):
                                                         with 
T.block("var_decode_intermediate_reindex_shared"):
                                                             v0 = 
T.axis.spatial(T.int64(1), T.int64(0))
-                                                            v1 = 
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * 
T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
-                                                            v2 = 
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                            v1 = 
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * 
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                            v2 = 
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * 
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * 
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
                                                             T.reads(W[v2 // 
T.int64(8), v1], S[v2 // T.int64(32), v1])
                                                             
T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
-                                                            
T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+                                                            
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                             
var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", 
T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % 
T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", 
T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), 
T.uint32(16))) + T.reinterpret("float32", 
T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], 
T.uint32(16)), T.uint32(65535)), T.ui [...]
-                                        for ax3_1, ax2_3, ax1_3 in 
T.grid(T.int64(16), T.int64(4), T.int64(2)):
+                                        for ax3_1, ax2_3, ax1_3 in 
T.grid(T.int64(16), T.int64(4), T.int64(4)):
                                             with T.block("matmul_update"):
                                                 v0 = 
T.axis.spatial(T.int64(1), ax0)
-                                                v1 = 
T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * 
T.int64(2) + ax1_3)
+                                                v1 = 
T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * 
T.int64(4) + ax1_3)
                                                 v2 = 
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 
* T.int64(4) + ax2_3)
                                                 v3 = 
T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
                                                 
T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], 
A_reindex_shared[T.int64(0), v1, v3], 
var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
                                                 
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
                                                 
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = 
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + 
A_reindex_shared[T.int64(0), v1, v3] * 
var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
-                                    for ax0_1, ax1, ax2 in T.grid(T.int64(1), 
T.int64(2), T.int64(4)):
-                                        with 
T.block("var_matmul_intermediate_reindex_local"):
-                                            v0 = T.axis.spatial(T.int64(1), 
ax0_1)
-                                            v1 = T.axis.spatial(T.int64(32), 
ax1_0 * T.int64(16) + ax1_2 * T.int64(2) + ax1)
-                                            v2 = T.axis.spatial(T.int64(4096), 
ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2)
-                                            T.reads(C[T.int64(0), v1, v2], 
var_matmul_intermediate_reindex_local[v0, v1, v2])
-                                            T.writes(Out[T.int64(0), v1, v2])
-                                            Out[T.int64(0), v1, v2] = 
C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
+                                    for ax0_1, ax1, ax2_0_1 in 
T.grid(T.int64(1), T.int64(4), T.int64(2)):
+                                        for ax2_1_1 in 
T.vectorized(T.int64(2)):
+                                            with 
T.block("var_matmul_intermediate_reindex_local"):
+                                                v0 = 
T.axis.spatial(T.int64(1), ax0_1)
+                                                v1 = 
T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1)
+                                                v2 = 
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + 
ax2_0_1 * T.int64(2) + ax2_1_1)
+                                                T.reads(C[T.int64(0), v1, v2], 
var_matmul_intermediate_reindex_local[v0, v1, v2])
+                                                T.writes(Out[T.int64(0), v1, 
v2])
+                                                Out[T.int64(0), v1, v2] = 
C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
     # fmt: on
 
 

Reply via email to