This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 50d1c97dc9 [DLIGHT][GPU] Add OpenCL dequant matmul schedule (#17187)
50d1c97dc9 is described below

commit 50d1c97dc982c6ddfe089852d1fbbac3ea629851
Author: krishnaraj36 <[email protected]>
AuthorDate: Tue Jul 23 20:57:53 2024 +0530

    [DLIGHT][GPU] Add OpenCL dequant matmul schedule (#17187)
    
    * [DLIGHT][GPU] Add OpenCL dequant matmul schedule
    
    1. Enhanced the GPU matmul schedule for the OpenCL Android and Windows backends.
    2. It gives about a 2X performance gain for the Llama-2-7B prefill process:
    Model                 Device                 Earlier prefill perf    Optimized prefill perf
    Llama-2-7B-chat-hf    Snapdragon® 8 Gen 3    27 tok/sec              50 tok/sec
    
    * Update matmul.py
---
 python/tvm/dlight/gpu/matmul.py        | 144 +++++++++++++++++++++++--
 tests/python/dlight/test_gpu_matmul.py | 192 +++++++++++++++++++++++++++------
 2 files changed, 292 insertions(+), 44 deletions(-)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index a5759941ca..25cc649b44 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -27,7 +27,7 @@ from tvm.tir import IterVar, PrimExpr, Var
 from tvm.tir.analysis import undefined_vars
 from tvm.tir.schedule.schedule import BlockRV
 
-from ..base import analysis
+from ..base import analysis, BlockInfo, IterInfo
 from .base import GPUScheduleRule
 
 
@@ -273,6 +273,32 @@ def get_index_map(block: tir.Block) -> 
Optional[Tuple[tir.IndexMap, ...]]:
     )
 
 
+def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> 
BlockInfo:
+    def _iter_kind(loop: tir.IterVar) -> str:
+        return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: 
"R"}.get(loop.iter_type, "O")
+
+    def _is_reduction_block(block: tir.schedule.BlockRV):
+        for iter_var in sch.get(block).iter_vars:
+            if _iter_kind(iter_var) == "R":
+                return True
+        return False
+
+    return BlockInfo(
+        name=sch.get(block).name_hint,
+        iters=[
+            IterInfo(
+                kind=_iter_kind(iter_var),
+                var=iter_var.var,
+                dom=iter_var.dom.extent,
+                loop_rv=loop_rv,
+            )
+            for loop_rv, iter_var in zip(sch.get_loops(block), 
sch.get(block).iter_vars)
+        ],
+        block_rv=block,
+        reduction_block=_is_reduction_block(block),
+    )
+
+
 def get_reduction_blocks(sch, blocks) -> bool:
     # Get the main computation block
     def is_reduction(block: BlockRV) -> bool:
@@ -914,17 +940,19 @@ class Matmul(GPUScheduleRule):
                 storage_align=True,
                 inner_x=False,
             )
-        elif target.kind.name == "opencl" and "android" in str(target.host):
+        elif target.kind.name == "opencl" and (
+            ("android" in str(target.host)) or ("windows" in str(target.host))
+        ):
             return Matmul.Config(
-                block_size_x=8,
-                block_size_y=16,
+                block_size_x=32,
+                block_size_y=8,
                 vthread_x=1,
                 vthread_y=1,
                 micro_size_x=8,
                 micro_size_y=2,
                 micro_size_k=16,
                 vector_size=8,
-                unroll=64,
+                unroll=4,
                 use_shared=False,
                 storage_align=False,
                 inner_x=True,
@@ -941,6 +969,7 @@ class Matmul(GPUScheduleRule):
         if not isinstance(func, tir.PrimFunc) or not 
self.is_target_available(target):
             return None
         sch = tir.Schedule(func)
+        config = self.get_configs(target)
         root_block = analysis.get_root_block(sch)
         blocks = sch.get_child_blocks(root_block)
 
@@ -953,9 +982,22 @@ class Matmul(GPUScheduleRule):
         index_maps = get_index_map(block_stmt)
         if index_maps is None:
             return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        main_block_info = get_block_info(sch, main_block)
+        iter_infos = main_block_info.iters
+
+        # Checks if it's an inner reduction by getting the last matrix's inner index
+        def is_inner_reduction(block_stmt, iter_infos):
+            end_it = block_stmt.reads[-1].region[-1].min
+            return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == 
"R"
+
+        if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, 
iter_infos):
+            ret = self.sch_outer_reduction(sch, config, main_block, blocks)
+            if ret is not None:
+                return ret
 
         # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, 
J, K]
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
         block = sch.reindex(main_block, ("read", 0))
         sch.transform_layout(block, ("write", 0), a_index_map)
         block = sch.reindex(main_block, ("read", 1))
@@ -994,10 +1036,7 @@ class Matmul(GPUScheduleRule):
             except:  # pylint: disable=bare-except
                 pass
 
-        # Step 2. Get schedule config.
-        config = self.get_configs(target)
-
-        # Step 3. Schedule matmul
+        # Step 2. Schedule matmul
         y_kernel_size = config.vthread_y * config.block_size_y * 
config.micro_size_y
         x_kernel_size = config.vthread_x * config.block_size_x * 
config.micro_size_x
         if config.inner_x:
@@ -1075,3 +1114,88 @@ class Matmul(GPUScheduleRule):
 
         sch.decompose_reduction(main_block, ko)
         return sch
+
+    def sch_outer_reduction(
+        self,
+        sch: tir.Schedule,
+        config: Config,
+        reduction_block: tir.schedule.BlockRV,
+        blocks: List[tir.schedule.BlockRV],
+    ) -> Optional[tir.Schedule]:
+        reduction_loops = sch.get_loops(reduction_block)
+        if not len(reduction_loops) == 4:
+            return None
+
+        mb, ms, n, k = reduction_loops
+        if not (
+            isinstance(sch.get(n).extent, tir.IntImm)
+            and isinstance(sch.get(mb).extent, tir.IntImm)
+            and isinstance(sch.get(ms).extent, tir.Var)
+        ):
+            return None
+
+        Threads_X, Threads_Y, VecSize, Unroll_M = (
+            config.block_size_x,
+            config.block_size_y,
+            config.vector_size,
+            config.unroll,
+        )
+
+        is_dequant_block = len(blocks) > 1
+        if is_dequant_block:
+            compute_block, dequant_block, matmul_block = blocks
+            sch.compute_inline(compute_block)
+        else:
+            (matmul_block,) = blocks
+
+        m = sch.fuse(mb, ms)
+
+        sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * 
VecSize, 1])
+
+        rmat_block, wmat_block = (
+            sch.get_producers(matmul_block)[0],
+            sch.get_consumers(matmul_block)[0],
+        )
+        mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M])
+        no, ni, nv = sch.split(n, [None, Threads_X, VecSize])
+        k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 
8])
+        sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)
+
+        sch.compute_at(rmat_block, k0)
+        if is_dequant_block:
+            sch.compute_at(dequant_block, k3)
+        sch.reverse_compute_at(wmat_block, mi)
+        sch.set_scope(rmat_block, 0, "shared")
+        sch.set_scope(matmul_block, 0, "local")
+        if is_dequant_block:
+            sch.set_scope(dequant_block, 0, "local")
+
+        sch.bind(mo, "blockIdx.y")
+        sch.bind(no, "blockIdx.x")
+        sch.bind(mi, "threadIdx.y")
+        sch.bind(ni, "threadIdx.x")
+        sch.vectorize(sch.get_loops(matmul_block)[-1])
+        if is_dequant_block:
+            sch.vectorize(sch.get_loops(dequant_block)[-1])
+
+        # Co-operative Memory Fetch
+        ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize])
+        sch.bind(ro, "threadIdx.x")
+        sch.vectorize(rv)
+
+        wv = sch.get_loops(wmat_block)[-1]
+        sch.vectorize(wv)
+
+        # Scale and Quant Cache
+        if is_dequant_block:
+            qb = sch.cache_read(dequant_block, 0, "local")
+            sb = sch.cache_read(dequant_block, 1, "local")
+            sch.compute_at(sb, k1)
+            sch.compute_at(qb, k2)
+            sch.set_scope(sb, 0, "local")
+            sch.set_scope(qb, 0, "local")
+            sch.vectorize(sch.get_loops(qb)[-1])
+            sch.vectorize(sch.get_loops(sb)[-1])
+
+        sch.decompose_reduction(matmul_block, k0)
+        return sch
diff --git a/tests/python/dlight/test_gpu_matmul.py 
b/tests/python/dlight/test_gpu_matmul.py
index ca32c286ab..4cef7f1c27 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -634,42 +634,166 @@ class TestMatmulAndroid(AndroidBeforeAfter):
         inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
         matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
         # with T.block("root"):
-        matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
-        for ax0_ax1_0_fused in T.thread_binding((m + T.int64(31)) // 
T.int64(32), thread="blockIdx.y"):
-            for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.x"):
-                for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
-                    for ax2_1 in T.thread_binding(T.int64(1), 
thread="vthread.x"):
-                        for ax1_2 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
-                            for ax2_2 in T.thread_binding(T.int64(8), 
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 64, 
"pragma_unroll_explicit": 1}):
-                                for ax1_3_init, ax2_3_0_init in 
T.grid(T.int64(2), T.int64(1)):
-                                    for ax2_3_1_init in 
T.vectorized(T.int64(8)):
-                                        with T.block("matmul_init"):
+        inp0_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // 
T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
+        matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // 
T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
+        for i2_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
+            for i0_i1_fused_0 in T.thread_binding((m + T.int64(31)) // 
T.int64(32), thread="blockIdx.y"):
+                for i2_1 in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                    for i0_i1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.y"):
+                        for i0_i1_fused_2_init in range(T.int64(4)):
+                            for i2_2_init in T.vectorized(T.int64(8)):
+                                with T.block("matmul_init"):
+                                    v_i0 = T.axis.spatial(T.int64(1), 
T.int64(0))
+                                    v_i1 = T.axis.spatial((m + T.int64(31)) // 
T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * 
T.int64(4) + i0_i1_fused_2_init)
+                                    v_i2 = T.axis.spatial(T.int64(4096), i2_0 
* T.int64(256) + i2_1 * T.int64(8) + i2_2_init)
+                                    T.reads()
+                                    T.writes(matmul_pad_local[v_i0, v_i1, 
v_i2])
+                                    matmul_pad_local[v_i0, v_i1, v_i2] = 
T.float32(0)
+                        for k_0 in range(T.int64(16)):
+                            for ax0 in range(T.int64(4)):
+                                for ax1_0 in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                                    for ax1_1 in T.vectorized(T.int64(8)):
+                                        with T.block("inp0_pad"):
                                             v0 = T.axis.spatial(T.int64(1), 
T.int64(0))
-                                            v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + 
ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3_init)
-                                            v2 = T.axis.spatial(T.int64(4096), 
ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0_init * 
T.int64(8) + ax2_3_1_init)
-                                            T.reads()
-                                            
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
-                                            
matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
-                                for ax3_0, ax3_1, ax1_3, ax2_3_0 in 
T.grid(T.int64(256), T.int64(16), T.int64(2), T.int64(1)):
-                                    for ax2_3_1 in T.vectorized(T.int64(8)):
-                                        with T.block("matmul_update"):
+                                            v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + 
i0_i1_fused_1 * T.int64(4) + ax0)
+                                            v2 = T.axis.spatial(T.int64(4096), 
k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
+                                            T.reads(inp0[v0, v1, v2])
+                                            T.writes(inp0_pad_shared[v0, v1, 
v2])
+                                            inp0_pad_shared[v0, v1, v2] = 
T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
+                            for k_1, k_2, k_3, i0_i1_fused_2 in 
T.grid(T.int64(8), T.int64(4), T.int64(8), T.int64(4)):
+                                for i2_2 in T.vectorized(T.int64(8)):
+                                    with T.block("matmul_update"):
+                                        v_i0 = T.axis.spatial(T.int64(1), 
T.int64(0))
+                                        v_i1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + 
i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
+                                        v_i2 = T.axis.spatial(T.int64(4096), 
i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
+                                        v_k = T.axis.reduce(T.int64(4096), k_0 
* T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
+                                        T.reads(matmul_pad_local[v_i0, v_i1, 
v_i2], inp0_pad_shared[v_i0, v_i1, v_k], inp1[v_k, v_i2])
+                                        T.writes(matmul_pad_local[v_i0, v_i1, 
v_i2])
+                                        matmul_pad_local[v_i0, v_i1, v_i2] = 
matmul_pad_local[v_i0, v_i1, v_i2] + inp0_pad_shared[v_i0, v_i1, v_k] * 
inp1[v_k, v_i2]
+                        for ax0 in range(T.int64(4)):
+                            for ax1 in T.vectorized(T.int64(8)):
+                                with T.block("matmul_pad"):
+                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                    v1 = T.axis.spatial(m, i0_i1_fused_0 * 
T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
+                                    v2 = T.axis.spatial(T.int64(4096), i2_0 * 
T.int64(256) + i2_1 * T.int64(8) + ax1)
+                                    T.where((i0_i1_fused_0 - (m + T.int64(31)) 
// T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * 
T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < m)
+                                    T.reads(matmul_pad_local[v0, v1, v2])
+                                    T.writes(matmul[v0, v1, v2])
+                                    matmul[v0, v1, v2] = matmul_pad_local[v0, 
v1, v2]
+
+
+class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
+    # fmt: off
+    @T.prim_func
+    def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), 
lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: 
T.handle, p_output0: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        seq_len = T.int64()
+        rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, 
T.int64(4096)), "float16")
+        matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, 
T.int64(12288)), "float16")
+        # with T.block("root"):
+        compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
+        dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), 
T.int64(12288)), "float16")
+        for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
+            with T.block("compute"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                T.reads(lv840[v_i0 // T.int64(8), v_i1])
+                T.writes(compute[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.Cast("float16", 
T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", 
v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
+        for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
+            with T.block("dequantize"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1])
+                T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
+                dequantize_intermediate_intermediate[v_i0, v_i1] = 
(compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1]
+        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), 
T.int64(4096)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(rms_norm260[v_i0, v_i1, v_k], 
dequantize_intermediate_intermediate[v_k, v_i2])
+                T.writes(matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                matmul_intermediate[v_i0, v_i1, v_i2] = 
matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * 
dequantize_intermediate_intermediate[v_k, v_i2]
+
+    @T.prim_func
+    def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), 
lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: 
T.handle, p_output0: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+        seq_len = T.int64()
+        rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, 
T.int64(4096)), "float16")
+        matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, 
T.int64(12288)), "float16")
+        # with T.block("root"):
+        dequantize_intermediate_intermediate_local = 
T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local")
+        rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", 
scope="shared")
+        matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", 
scope="local")
+        lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", 
scope="local")
+        lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)), 
"float16", scope="local")
+        for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
+            for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // 
T.int64(32), thread="blockIdx.y"):
+                for i2_1 in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                    for i0_i1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.y"):
+                        for i0_i1_fused_2_init in range(T.int64(4)):
+                            for i2_2_init in T.vectorized(T.int64(8)):
+                                with T.block("matmul_init"):
+                                    v_i0 = T.axis.spatial(T.int64(1), 
T.int64(0))
+                                    v_i1 = T.axis.spatial((seq_len + 
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + 
i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2_init)
+                                    v_i2 = T.axis.spatial(T.int64(12288), i2_0 
* T.int64(256) + i2_1 * T.int64(8) + i2_2_init)
+                                    T.reads()
+                                    
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
+                                    matmul_intermediate_pad_local[v_i0, v_i1, 
v_i2] = T.float16(0)
+                        for k_0 in range(T.int64(16)):
+                            for ax0 in range(T.int64(4)):
+                                for ax1_0 in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                                    for ax1_1 in T.vectorized(T.int64(8)):
+                                        with T.block("rms_norm260_pad"):
                                             v0 = T.axis.spatial(T.int64(1), 
T.int64(0))
-                                            v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + 
ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3)
-                                            v2 = T.axis.spatial(T.int64(4096), 
ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0 * 
T.int64(8) + ax2_3_1)
-                                            v3 = T.axis.reduce(T.int64(4096), 
ax3_0 * T.int64(16) + ax3_1)
-                                            
T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0[T.int64(0), v1, v3], 
inp1[v3, v2])
-                                            
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
-                                            
matmul_reindex_pad_local[T.int64(0), v1, v2] = 
matmul_reindex_pad_local[T.int64(0), v1, v2] + T.if_then_else(v1 < m, 
inp0[T.int64(0), v1, v3], T.float32(0)) * inp1[v3, v2]
-                                for ax0, ax1, ax2_0_1 in T.grid(T.int64(1), 
T.int64(2), T.int64(1)):
-                                    for ax2_1_1 in T.vectorized(T.int64(8)):
-                                        with 
T.block("matmul_reindex_pad_local"):
-                                            v0 = T.axis.spatial(T.int64(1), 
ax0)
-                                            v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + 
ax1_2 * T.int64(2) + ax1)
-                                            v2 = T.axis.spatial(T.int64(4096), 
ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1)
-                                            T.where(ax0_ax1_0_fused * 
T.int64(32) + ax1_2 * T.int64(2) + ax1 < m)
-                                            
T.reads(matmul_reindex_pad_local[v0, v1, v2])
-                                            T.writes(matmul[T.int64(0), v1, 
v2])
-                                            matmul[T.int64(0), v1, v2] = 
matmul_reindex_pad_local[v0, v1, v2]
+                                            v1 = T.axis.spatial((seq_len + 
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + 
i0_i1_fused_1 * T.int64(4) + ax0)
+                                            v2 = T.axis.spatial(T.int64(4096), 
k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
+                                            T.reads(rms_norm260[v0, v1, v2])
+                                            
T.writes(rms_norm260_pad_shared[v0, v1, v2])
+                                            rms_norm260_pad_shared[v0, v1, v2] 
= T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0))
+                            for k_1 in range(T.int64(8)):
+                                for ax0 in T.vectorized(T.int64(8)):
+                                    with T.block("lv841_local"):
+                                        v0 = T.axis.spatial(T.int64(128), k_0 
* T.int64(8) + k_1)
+                                        v1 = T.axis.spatial(T.int64(12288), 
i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
+                                        T.reads(lv841[v0, v1])
+                                        T.writes(lv841_local[v0, v1])
+                                        lv841_local[v0, v1] = lv841[v0, v1]
+                                for k_2 in range(T.int64(4)):
+                                    for ax0 in T.vectorized(T.int64(8)):
+                                        with T.block("lv840_local"):
+                                            v0 = T.axis.spatial(T.int64(512), 
k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)
+                                            v1 = 
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
+                                            T.reads(lv840[v0, v1])
+                                            T.writes(lv840_local[v0, v1])
+                                            lv840_local[v0, v1] = lv840[v0, v1]
+                                    for k_3 in range(T.int64(8)):
+                                        for ax0 in T.vectorized(T.int64(8)):
+                                            with T.block("dequantize"):
+                                                v_i0 = 
T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * 
T.int64(8) + k_3)
+                                                v_i1 = 
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
+                                                T.reads(lv840_local[v_i0 // 
T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1])
+                                                
T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1])
+                                                
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], 
T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - 
T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1]
+                                        for i0_i1_fused_2 in range(T.int64(4)):
+                                            for i2_2 in 
T.vectorized(T.int64(8)):
+                                                with T.block("matmul_update"):
+                                                    v_i0 = 
T.axis.spatial(T.int64(1), T.int64(0))
+                                                    v_i1 = 
T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), 
i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
+                                                    v_i2 = 
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
+                                                    v_k = 
T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * 
T.int64(8) + k_3)
+                                                    
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], 
rms_norm260_pad_shared[v_i0, v_i1, v_k], 
dequantize_intermediate_intermediate_local[v_k, v_i2])
+                                                    
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
+                                                    
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = 
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, 
v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
+                        for ax0 in range(T.int64(4)):
+                            for ax1 in T.vectorized(T.int64(8)):
+                                with T.block("matmul_intermediate_pad"):
+                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                    v1 = T.axis.spatial(seq_len, i0_i1_fused_0 
* T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
+                                    v2 = T.axis.spatial(T.int64(12288), i2_0 * 
T.int64(256) + i2_1 * T.int64(8) + ax1)
+                                    T.where((i0_i1_fused_0 - (seq_len + 
T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and 
i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len)
+                                    T.reads(matmul_intermediate_pad_local[v0, 
v1, v2])
+                                    T.writes(matmul_intermediate[v0, v1, v2])
+                                    matmul_intermediate[v0, v1, v2] = 
matmul_intermediate_pad_local[v0, v1, v2]
     # fmt: on
 
 

Reply via email to