This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bd359685c7 [Unity][DLight] Update gemv rule (#15490)
bd359685c7 is described below

commit bd359685c7d2b6cb666dfa4eca1b6902b22d2a4f
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Aug 8 05:02:24 2023 -0700

    [Unity][DLight] Update gemv rule (#15490)
---
 python/tvm/dlight/gpu/gemv.py        | 380 ++++++++++++++++++-------
 python/tvm/dlight/gpu/utils.py       |   2 +
 tests/python/dlight/test_gpu_gemv.py | 536 ++++++++++++++++++++++++-----------
 3 files changed, 653 insertions(+), 265 deletions(-)

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index 13dee1cd54..b063883800 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -16,6 +16,7 @@
 # under the License.
 """A rule for GEMV and DecodeGEMV."""
 import re
+from functools import reduce
 from typing import List, Optional, Union
 
 from tvm import DataType, arith, ir, tir
@@ -124,6 +125,8 @@ def normalize(
             if c_loops:
                 return None
             loop, c_loop = sch.split(loop, factors=[None, 
split_expr.lower_factor])
+            # we expect the innermost dim to be grouped at the moment
+            assert not (is_reduction ^ is_inner_reduction)
             c_loops.append(c_loop)
         if is_reduction:
             r_loops.append(loop)
@@ -169,6 +172,10 @@ class GEMV(ScheduleRule):
             return None
 
         block_info = block_infos[0]
+        if len(block_info.iters) not in [2, 3]:
+            # either [B, S, R] = [B, S, R] * [B, R]
+            # or [S, R] = [S, R] * [R]
+            return None
         block = block_info.block_rv
         vector_input_buffers = is_gemv(sch, block_info)
         if vector_input_buffers is None:
@@ -179,14 +186,13 @@ class GEMV(ScheduleRule):
 
         # Step 2. Do the scheduling
         if is_inner_reduction:
-            # print(func)
             self.sch_inner_reduction(sch, target, block, vector_input_buffers, 
epilogue)
             return sch
         else:
             # TODO: Need to handle GEMV with KN layout
             return None
 
-    def sch_inner_reduction(  # pylint: disable=too-many-arguments
+    def sch_inner_reduction(  # pylint: disable=too-many-arguments, 
invalid-name, unused-argument
         self,
         sch: tir.Schedule,
         target: Target,
@@ -195,106 +201,282 @@ class GEMV(ScheduleRule):
         epilogue_info: Optional[BlockInfo],
     ):
         """Schedule the inner reduction block."""
-        # pylint: disable=invalid-name
-        _, s, r, _ = sch.get_loops(block)
-        # TODO: make it tunable
-        vec_bytes = 16 if target.kind.name == "cuda" else 8
-        unroll_number = 256 if target.kind.name == "cuda" else 64
+
+        def get_max_factor(n, factors):
+            factors = sorted(factors, reverse=True)
+            for factor in factors:
+                if n % factor == 0:
+                    return factor
+            return 1
+
+        def apply(
+            sch: tir.Schedule,
+            gemv,
+            TAG_S,
+            TAG_R,
+            TS,
+            TR,
+            TILE_S,
+            TILE_R,
+            VEC_LOAD,
+            VEC_C,
+            LOAD_V_SHARED,
+            LOAD_V_VEC,
+            UNROLL,
+        ):
+            # rfactor: reduce to tx * vec_c
+            _, s, r, c = sch.get_loops(block=gemv)
+            s = sch.fuse(_, s)
+            r = sch.fuse(r, c)
+            bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], 
preserve_unit_iters=True)
+            r, tr, tile_r_vec_n, vec_c = sch.split(
+                r, factors=[None, TR, TILE_R // VEC_C, VEC_C], 
preserve_unit_iters=True
+            )
+            sch.reorder(r, tile_r_vec_n, tr, vec_c)
+            tr_vec_c = sch.fuse(tr, vec_c)
+            rf = sch.rfactor(tr_vec_c, 0)
+
+            # rfactor: reduce to tx
+            bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], 
preserve_unit_iters=True)
+            rf2 = sch.rfactor(tr, 0)
+
+            # bind, vectorize compute
+            bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], 
preserve_unit_iters=True)
+            sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c)
+            sch.bind(bx, "blockIdx.x")
+            sch.bind(ts, TAG_S)
+            sch.bind(tr, TAG_R)
+            sch.vectorize(vec_c)
+
+            shared_mem_usage = 0
+            for buf in vector_input_buffers:
+                buf_size = reduce(
+                    lambda x, y: x * y, buf.shape, 
tir.IntImm(buf.shape[0].dtype, 1)
+                ) * get_bytes(buf.dtype)
+                shared_mem_usage += buf_size
+            LOAD_V_SHARED = (
+                LOAD_V_SHARED
+                and isinstance(shared_mem_usage, tir.IntImm)
+                and shared_mem_usage.value <= 
target.max_shared_memory_per_block
+            )
+
+            # vectorize load A
+            # TODO: this is problematic, since the number of loops
+            # depends on the number of dimensions of A_q
+            Aq_local = sch.cache_read(rf, read_buffer_index=1, 
storage_scope="local")
+            sch.compute_at(Aq_local, r, preserve_unit_loops=True)
+            s_local, r_local = sch.get_loops(block=Aq_local)[-2:]
+            s_local, vec_load = sch.split(
+                s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
+            )
+            sch.reorder(s_local, r_local, vec_load)  # either s_local or 
r_local should be 1
+            sch.vectorize(vec_load)
+
+            # load vector into shared memory, shape should be the whole vector
+            if LOAD_V_SHARED:
+                assert len(vector_input_buffers) == 1
+                V_shared = sch.cache_read(rf, read_buffer_index=0, 
storage_scope="shared")
+                sch.compute_at(V_shared, tr, preserve_unit_loops=True)
+                l = sch.get_loops(block=V_shared)[-1]
+                loop: tir.For = sch.get(l)
+                if isinstance(loop.extent, tir.IntImm):
+                    # avoid introducing predicates when vector length is too 
large
+                    vec_length = max(
+                        min(
+                            get_max_factor(
+                                (int)(loop.extent),
+                                [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * 
TR * 8],
+                            )
+                            // TS
+                            // TR,
+                            LOAD_V_VEC,
+                        ),
+                        1,
+                    )
+                else:
+                    vec_length = LOAD_V_VEC
+                if TAG_R == "threadIdx.x":
+                    _, ty, tx, vec = sch.split(
+                        l, factors=[None, TS, TR, vec_length], 
preserve_unit_iters=True
+                    )
+                else:
+                    _, ty, tx, vec = sch.split(
+                        l, factors=[None, TR, TS, vec_length], 
preserve_unit_iters=True
+                    )
+                sch.bind(ty, "threadIdx.y")
+                sch.bind(tx, "threadIdx.x")
+                sch.vectorize(vec)
+
+            # reduce tile_s * tr * vec to tile_s * tr
+            sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+            tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:]
+            ts_tile_s = sch.fuse(*ts_tile_s)
+            ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], 
preserve_unit_iters=True)
+            tile_s, vec_s = sch.split(
+                tile_s,
+                factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])],
+                preserve_unit_iters=True,
+            )
+            sch.reorder(ts, tr, tile_s, vec_s, vec_c)
+            sch.bind(ts, TAG_S)
+            sch.bind(tr, TAG_R)
+            sch.vectorize(vec_s)
+
+            # reduce tile_s * tr to tile_s
+            sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+            tr, *ts_tile_s = sch.get_loops(block=gemv)[1:]
+            ts_tile_s = sch.fuse(*ts_tile_s)
+            ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], 
preserve_unit_iters=True)
+            sch.reorder(tile_s, ts, tr)
+            sch.bind(ts, TAG_S)
+            sch.bind(tr, TAG_R)
+
+            sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[3])
+            sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+            sch.set_scope(rf, buffer_index=0, storage_scope="local")
+            sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+            unroll_factor = UNROLL
+
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf)[3],
+                ann_key="pragma_auto_unroll_max_step",
+                ann_val=unroll_factor,
+            )
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf)[3], 
ann_key="pragma_unroll_explicit", ann_val=1
+            )
+
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[3],
+                ann_key="pragma_auto_unroll_max_step",
+                ann_val=unroll_factor,
+            )
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[3], 
ann_key="pragma_unroll_explicit", ann_val=1
+            )
+
+            if LOAD_V_SHARED:
+                sch.annotate(
+                    block_or_loop=sch.get_loops(V_shared)[-4],
+                    ann_key="pragma_unroll_explicit",
+                    ann_val=unroll_factor,
+                )
+                sch.annotate(
+                    block_or_loop=sch.get_loops(V_shared)[-4], 
ann_key="pragma_vectorize", ann_val=1
+                )
+
+            # Schedule epilogue
+            if epilogue_info is not None:
+                epilogue = epilogue_info.block_rv
+                if is_broadcast_epilogue(sch, block, epilogue):
+                    sch.reverse_compute_at(epilogue, bx)
+                    sch.set_scope(block, 0, "shared")
+                    _, _, *s = sch.get_loops(epilogue)  # pylint: 
disable=invalid-name
+                    _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
+                    sch.bind(tx, "threadIdx.x")
+                else:
+                    sch.reverse_compute_at(epilogue, bx, 
preserve_unit_loops=True)
+                    ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
+                    ts_tile_s = sch.get_loops(epilogue)[-1]
+                    ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], 
preserve_unit_iters=True)
+                    sch.bind(ts, TAG_S)
+                    sch.set_scope(block, 0, "local")
+            # pylint: enable=invalid-name
+            return sch
 
         def get_extent(loop_rv: tir.schedule.LoopRV):
             loop: tir.For = sch.get(loop_rv)
-            return loop.extent.value if isinstance(loop.extent, tir.IntImm) 
else 1
+            return loop.extent.value if isinstance(loop.extent, tir.IntImm) 
else loop.extent
 
         # Specify the `len_tx` and `len_ty` according to the loop extent
-        len_s, len_r = get_extent(s), get_extent(r)
-        if len_r >= 4096 and len_r % 128 == 0:
-            len_tx = 128
-        elif 1024 < len_r <= 2048 and len_r % 64 == 0:
-            len_tx = 64
-        else:
-            len_tx = 32
-
-        if len_s >= 4096:
-            len_ty = 8
-        else:
-            len_ty = min(len_s, 4)
-
-        # Use `split_k` to prevent too large shared memory usage
-        split_k: int = 4
-
-        _, tx = sch.split(r, [None, len_tx], preserve_unit_iters=True)
-        # Schedule the RF block
-        rf = sch.rfactor(tx, 0)
-        batch, bx, r, tx, _ = sch.get_loops(rf)
-        sch.reorder(bx, tx, r)
-        ro, ri = sch.split(r, [split_k, None], preserve_unit_iters=True)
-        bx, ty = sch.split(bx, [None, len_ty], preserve_unit_iters=True)
-
-        sch.bind(batch, "blockIdx.y")
-        sch.bind(bx, "blockIdx.x")
-        sch.bind(ty, "threadIdx.y")
-        sch.bind(tx, "threadIdx.x")
-        sch.annotate(ro, "pragma_auto_unroll_max_step", unroll_number)
-        sch.annotate(ro, "pragma_unroll_explicit", 1)
-
+        batch, s, r, c = sch.get_loops(block=block)
+        len_batch, len_s, len_r, len_c = (
+            get_extent(batch),
+            get_extent(s),
+            get_extent(r),
+            get_extent(c),
+        )
+        len_S = len_batch * len_s
+        len_R = len_r * len_c
+
+        TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
         if target.kind.name == "cuda":
-            # Cache read the vector
-            def cache_shared(index: int):
-                block: tir.Block = sch.get(rf)
-                type_bytes: int = get_bytes(block.reads[index].buffer.dtype)
-                cache = sch.cache_read(rf, index, "shared")
-                sch.compute_at(cache, ro, preserve_unit_loops=True)
-                fused = sch.fuse(*sch.get_loops(cache)[5:])
-                loop: tir.For = sch.get(fused)
-                vec_length = vec_bytes // type_bytes
-                if isinstance(loop.extent, tir.IntImm):
-                    # avoid introducing predicates when vector length is too 
large
-                    vec_length = min(loop.extent // len_ty // len_tx, 
vec_length)
-                _, _ty, _tx, _vec = sch.split(fused, [None, len_ty, len_tx, 
vec_length])
-                sch.bind(_ty, "threadIdx.y")
-                sch.bind(_tx, "threadIdx.x")
-                sch.vectorize(_vec)
-
-            def cache_local(index: int):
-                block: tir.Block = sch.get(rf)
-                type_bytes: int = get_bytes(block.reads[index].buffer.dtype)
-                vec_length = vec_bytes // type_bytes
-                cache = sch.cache_read(rf, index, "local")
-                sch.compute_at(cache, ri, preserve_unit_loops=True)
-                fused = sch.fuse(*sch.get_loops(cache)[6:])
-                loop: tir.For = sch.get(fused)
-                if isinstance(loop.extent, tir.IntImm) and loop.extent.value % 
vec_length == 0:
-                    _, _vec = sch.split(fused, [None, vec_length])
-                    sch.vectorize(_vec)
-                elif isinstance(loop.extent, tir.IntImm) and loop.extent.value 
< vec_length:
-                    sch.vectorize(fused)
-
-            for buffer in vector_input_buffers:
-                index = vector_input_buffers.index(buffer)
-                cache_shared(index)
-                cache_local(index)
-
-            # TODO: cache scale buffer in Decode-GEMV to shared memory
-
-        sch.set_scope(rf, 0, "local")
-        sch.decompose_reduction(rf, ro)
-        # Schedule the write back block
-        sch.reverse_compute_at(block, ty, preserve_unit_loops=True)
-        _, _, _, tx, *s = sch.get_loops(block)
-        s = sch.fuse(*s)
-        sch.reorder(s, tx)
-        sch.bind(tx, "threadIdx.x")
-        # Schedule epilogue
-        if epilogue_info is not None:
-            epilogue = epilogue_info.block_rv
-            if is_broadcast_epilogue(sch, block, epilogue):
-                sch.reverse_compute_at(epilogue, bx)
-                sch.set_scope(block, 0, "shared")
-                _, _, *s = sch.get_loops(epilogue)  # pylint: 
disable=invalid-name
-                _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx])
-                sch.bind(tx, "threadIdx.x")
-            else:
-                # NOTE: Need to ensure tx_len == 32, so that can use `local` 
stage here
-                sch.reverse_compute_at(epilogue, ty)
-                sch.set_scope(block, 0, "local")
-        # pylint: enable=invalid-name
+            VEC_C = 4
+            LOAD_V_SHARED = True
+            LOAD_V_VEC = 8
+            UNROLL = 256
+            if isinstance(len_S, int):
+                if len_S > len_R:
+                    TS, TR = 4, 64
+                else:
+                    TS, TR = 16, 32
+        elif target.kind.name == "metal":
+            VEC_C = 2
+            LOAD_V_SHARED = True
+            LOAD_V_VEC = 4
+            UNROLL = 256
+            TS, TR = 64, 8
+        elif target.kind.name == "rocm":
+            VEC_C = 4
+            LOAD_V_SHARED = True
+            LOAD_V_VEC = 8
+            UNROLL = 256
+            if isinstance(len_S, int):
+                if len_S > len_R:
+                    TS, TR = 1, 128
+                else:
+                    TS, TR = 8, 64
+        elif target.kind.name == "opencl" and "android" in str(target.host):
+            TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
+            VEC_C = 8
+            LOAD_V_SHARED = False
+            LOAD_V_VEC = -1
+            UNROLL = 8
+            TS, TR = 2, 32
+        elif target.kind.name == "vulkan":
+            VEC_C = 4
+            LOAD_V_SHARED = True
+            LOAD_V_VEC = 4
+            UNROLL = 256
+            if isinstance(len_S, int):
+                if len_S > len_R:
+                    TS, TR = 4, 32
+                else:
+                    TS, TR = 16, 32
+        else:
+            VEC_C = 1
+            LOAD_V_SHARED = False
+            LOAD_V_VEC = -1
+            UNROLL = 64
+            TS, TR = 1, 64
+
+        if not isinstance(len_S, int):
+            TS, TR = 1, 64
+        TILE_S, TILE_R = (
+            1,
+            len_c
+            if len_c > 1
+            else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) 
// TR, 1),
+        )
+        VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)
+        VEC_LOAD = 1
+
+        return apply(
+            sch,
+            gemv=block,
+            TAG_S=TAG_S,
+            TAG_R=TAG_R,
+            TS=TS,
+            TR=TR,
+            TILE_S=TILE_S,
+            TILE_R=TILE_R,
+            VEC_LOAD=VEC_LOAD,
+            VEC_C=VEC_C,
+            LOAD_V_SHARED=LOAD_V_SHARED,
+            LOAD_V_VEC=LOAD_V_VEC,
+            UNROLL=UNROLL,
+        )
diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py
index 4fcc762942..9f9a9c5ae4 100644
--- a/python/tvm/dlight/gpu/utils.py
+++ b/python/tvm/dlight/gpu/utils.py
@@ -51,6 +51,8 @@ def suggest_threads_per_block(
 ) -> List[int]:
     if target.kind.name == "cuda":
         threads = 256
+    elif target.kind.name == "rocm":
+        threads = 256
     else:
         threads = 64
     results: List[Optional[int]] = []
diff --git a/tests/python/dlight/test_gpu_gemv.py 
b/tests/python/dlight/test_gpu_gemv.py
index 6cb5cceb43..fd6850ac60 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -90,61 +90,92 @@ class TestGEMV(BaseBeforeAfter):
         var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n))
         # with T.block("root"):
         var_NT_matmul_intermediate_local = T.alloc_buffer((1, 32, 1, n), 
"float16", scope="local")
-        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 32, 1, 
n), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 32, 1, 
n), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 32, 1, 
n), "float16", scope="local")
+        lv1638_local = T.alloc_buffer((1, 32, n, 128), "float16", 
scope="local")
         lv1637_shared = T.alloc_buffer((1, 32, 1, 128), "float16", 
scope="shared")
-        lv1637_shared_local = T.alloc_buffer((1, 32, 1, 128), "float16", 
scope="local")
-        for ax0_fused in T.thread_binding(32, thread="blockIdx.y"):
-            for ax1_fused_0 in T.thread_binding(n, thread="blockIdx.x"):
-                for ax1_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
-                    for ax2_fused_1 in T.thread_binding(32, 
thread="threadIdx.x"):
-                        with T.block("NT_matmul_rf_init"):
-                            vax2_fused_1, v0 = T.axis.remap("SS", 
[ax2_fused_1, ax0_fused])
-                            v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1)
-                            T.reads()
-                            
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
-                            var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
0, v0, 0, v1] = T.float16(0)
-                        for ax2_fused_0_0 in T.serial(4, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                            for ax0_ax1_ax2_ax3_fused_0 in range(1):
-                                for ax0_ax1_ax2_ax3_fused_1 in 
T.thread_binding(1, thread="threadIdx.y"):
-                                    for ax0_ax1_ax2_ax3_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
-                                        for ax0_ax1_ax2_ax3_fused_3 in 
T.vectorized(1):
-                                            with T.block("lv1637_shared"):
-                                                v0 = T.axis.spatial(1, 0)
-                                                v1 = T.axis.spatial(32, 
ax0_fused)
-                                                v2 = T.axis.spatial(1, 0)
-                                                v3 = T.axis.spatial(128, 
ax2_fused_0_0 * 32 + ax0_ax1_ax2_ax3_fused_0 * 32 + ax0_ax1_ax2_ax3_fused_1 * 
32 + ax0_ax1_ax2_ax3_fused_2 + ax0_ax1_ax2_ax3_fused_3)
-                                                T.reads(lv1637[v0, v1, v2, v3])
-                                                T.writes(lv1637_shared[v0, v1, 
v2, v3])
-                                                lv1637_shared[v0, v1, v2, v3] 
= lv1637[v0, v1, v2, v3]
-                            for ax2_fused_0_1 in range(1):
-                                for ax0_ax1_ax2_ax3_fused in T.vectorized(1):
-                                    with T.block("lv1637_shared_local"):
-                                        v0 = T.axis.spatial(1, 0)
-                                        v1 = T.axis.spatial(32, ax0_fused)
-                                        v2 = T.axis.spatial(1, 0)
-                                        v3 = T.axis.spatial(128, ax2_fused_0_0 
* 32 + ax2_fused_1)
-                                        T.reads(lv1637_shared[v0, v1, v2, v3])
-                                        T.writes(lv1637_shared_local[v0, v1, 
v2, v3])
-                                        lv1637_shared_local[v0, v1, v2, v3] = 
lv1637_shared[v0, v1, v2, v3]
-                                for u in range(1):
-                                    with T.block("NT_matmul_rf_update"):
-                                        vax2_fused_1, v0 = T.axis.remap("SS", 
[ax2_fused_1, ax0_fused])
-                                        v1 = T.axis.spatial(n, ax1_fused_0 + 
ax1_fused_1)
-                                        vax2_fused_0 = T.axis.reduce(4, 
ax2_fused_0_0 + ax2_fused_0_1)
-                                        
T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1], 
lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1], lv1638[0, v0, 
v1, vax2_fused_0 * 32 + vax2_fused_1])
-                                        
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
-                                        
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = 
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] + 
lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1] * lv1638[0, v0, 
v1, vax2_fused_0 * 32 + vax2_fused_1]
-                    for ax1_ax2_fused in range(1):
-                        for ax0 in T.thread_binding(32, thread="threadIdx.x"):
-                            with T.block("NT_matmul"):
-                                vax2_fused_1, v0, v1 = T.axis.remap("RSS", 
[ax0, ax0_fused, ax1_fused_0])
-                                
T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
-                                T.writes(var_NT_matmul_intermediate_local[0, 
v0, 0, v1])
-                                with T.init():
-                                    var_NT_matmul_intermediate_local[0, v0, 0, 
v1] = T.float16(0)
-                                var_NT_matmul_intermediate_local[0, v0, 0, v1] 
= var_NT_matmul_intermediate_local[0, v0, 0, v1] + 
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1]
+        for ax0_fused_ax1_fused_fused_0 in T.thread_binding(n * 32, 
thread="blockIdx.x"):
+            for ax0_fused_ax1_fused_fused_1 in T.thread_binding(1, 
thread="threadIdx.y"):
+                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in 
T.thread_binding(64, thread="threadIdx.x"):
+                    for ax0, ax1, ax2 in T.grid(1, 1, 1):
+                        for ax3_0 in T.serial(1, 
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+                            for ax3_1 in T.thread_binding(1, 
thread="threadIdx.y"):
+                                for ax3_2 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                    for ax3_3 in T.vectorized(2):
+                                        with T.block("lv1637_shared"):
+                                            v0 = T.axis.spatial(1, ax0)
+                                            v1 = T.axis.spatial(32, 
ax0_fused_ax1_fused_fused_0 // n + ax1)
+                                            v2 = T.axis.spatial(1, ax2)
+                                            v3 = T.axis.spatial(128, ax3_0 * 
128 + ax3_1 * 128 + ax3_2 * 2 + ax3_3)
+                                            T.reads(lv1637[v0, v1, v2, v3])
+                                            T.writes(lv1637_shared[v0, v1, v2, 
v3])
+                                            lv1637_shared[v0, v1, v2, v3] = 
lv1637[v0, v1, v2, v3]
+                    for ax0_fused_ax1_fused_fused_2_init in range(1):
+                        for 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(2):
+                            with T.block("NT_matmul_rf_init"):
+                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused 
= T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                v0 = T.axis.spatial(32, 
(ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + 
ax0_fused_ax1_fused_fused_2_init) // n)
+                                v1 = T.axis.spatial(n, 
(ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + 
ax0_fused_ax1_fused_fused_2_init) % n)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 0, v0, 0, v1])
+                                
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 0, v0, 0, v1] = T.float16(0)
+                    for ax2_fused_u_fused_0 in T.serial(1, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2):
+                            for ax2_1 in T.vectorized(1):
+                                with T.block("lv1638_local"):
+                                    v0 = T.axis.spatial(1, ax0)
+                                    v1 = T.axis.spatial(32, 
ax0_fused_ax1_fused_fused_0 // n + ax1)
+                                    v2 = T.axis.spatial(n, 
ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
+                                    v3 = T.axis.spatial(128, 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3)
+                                    T.reads(lv1638[v0, v1, v2, v3])
+                                    T.writes(lv1638_local[v0, v1, v2, v3])
+                                    lv1638_local[v0, v1, v2, v3] = lv1638[v0, 
v1, v2, v3]
+                        for ax0_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 
in T.grid(1, 1):
+                            for 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(2):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                    v0 = T.axis.spatial(32, 
(ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + 
ax0_fused_ax1_fused_fused_2) // n)
+                                    v1 = T.axis.spatial(n, 
(ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + 
ax0_fused_ax1_fused_fused_2) % n)
+                                    vax2_fused_u_fused_2, vax2_fused_u_fused_0 
= T.axis.remap("RR", [ax2_fused_u_fused_2, ax2_fused_u_fused_0])
+                                    
T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 0, v0, 0, v1], lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + 
vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused], 
lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused])
+                                    
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 0, v0, 0, v1])
+                                    
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 0, v0, 0, v1] = 
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 0, v0, 0, v1] + lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + 
vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused] * 
lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused]
+            for ax2_ax3_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
+                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+                    for ax2_ax3_fused_1_0 in T.serial(1, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax2_ax3_fused_1_1 in T.vectorized(1):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(64, ax0)
+                                v0 = T.axis.spatial(32, 
ax0_fused_ax1_fused_fused_0 // n)
+                                v1 = T.axis.spatial(n, 
ax0_fused_ax1_fused_fused_0 % n)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 0, v0, 0, v1])
+                                
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 0, v0, 0, v1] = T.float16(0)
+                            for ax1 in range(2):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, 
ax1])
+                                    v0 = T.axis.spatial(32, 
ax0_fused_ax1_fused_fused_0 // n)
+                                    v1 = T.axis.spatial(n, 
ax0_fused_ax1_fused_fused_0 % n)
+                                    
T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 0, v0, 0, v1], 
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0
 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1])
+                                    
T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 0, v0, 0, v1])
+                                    
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 0, v0, 0, v1] = 
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 0, v0, 0, v1] + 
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0
 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1]
+            for ax1_ax2_fused_1 in range(1):
+                for ax1_ax2_fused_0 in T.thread_binding(1, 
thread="threadIdx.y"):
+                    for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+                        with T.block("NT_matmul"):
+                            vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = 
T.axis.reduce(64, ax0)
+                            v0 = T.axis.spatial(32, 
ax0_fused_ax1_fused_fused_0 // n)
+                            v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 
% n)
+                            
T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 0, v0, 0, v1])
+                            T.writes(var_NT_matmul_intermediate_local[0, v0, 
0, v1])
+                            with T.init():
+                                var_NT_matmul_intermediate_local[0, v0, 0, v1] 
= T.float16(0)
+                            var_NT_matmul_intermediate_local[0, v0, 0, v1] = 
var_NT_matmul_intermediate_local[0, v0, 0, v1] + 
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 0, v0, 0, v1]
+            for ax0_ax1_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
+                for ax0_ax1_fused_1 in range(1):
                     with T.block("compute"):
-                        v0, v1 = T.axis.remap("SS", [ax0_fused, ax1_fused_0])
+                        v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // 
n)
+                        v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
                         T.reads(var_NT_matmul_intermediate_local[0, v0, 0, 
v1], lv1614[0, 0, 0, v1])
                         T.writes(var_compute_intermediate[0, v0, 0, v1])
                         var_compute_intermediate[0, v0, 0, v1] = 
T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[0, v0, 0, v1] * 
T.float16(0.088397790055248615), T.float16(-65504)), lv1614[0, 0, 0, v1]))
@@ -152,10 +183,10 @@ class TestGEMV(BaseBeforeAfter):
     # fmt: on
 
 
-class TestDecodeGEMV1(BaseBeforeAfter):
+def test_decode_gemv1():
     # fmt: off
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: 
T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), 
var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
         T.func_attr({"tir.noalias": T.bool(True)})
         # with T.block("root"):
@@ -175,72 +206,95 @@ class TestDecodeGEMV1(BaseBeforeAfter):
                     var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                 var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * 
p_output0_intermediate[v_i2, v_k]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: 
T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), 
var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
         T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
         # with T.block("root"):
-        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1, 
22016), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((256, 1, 1, 
22016), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 1, 
22016), "float16", scope="local")
+        lv571_local = T.alloc_buffer((22016, 512), "uint32", scope="local")
         lv1654_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
-        lv1654_shared_local = T.alloc_buffer((1, 1, 4096), "float16", 
scope="local")
-        for u_fused in T.thread_binding(1, thread="blockIdx.y"):
-            for ax0_fused_0 in T.thread_binding(2752, thread="blockIdx.x"):
-                for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
-                    for ax1_0_fused_1 in T.thread_binding(32, 
thread="threadIdx.x"):
-                        with T.block("NT_matmul_rf_init"):
-                            vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
-                            v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + 
ax0_fused_1)
-                            T.reads()
-                            
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-                            
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
-                        for ax1_0_fused_0_0 in T.serial(4, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                            for ax0_ax1_ax2_fused_0 in range(1):
-                                for ax0_ax1_ax2_fused_1 in T.thread_binding(8, 
thread="threadIdx.y"):
-                                    for ax0_ax1_ax2_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
-                                        for ax0_ax1_ax2_fused_3 in 
T.vectorized(4):
-                                            with T.block("lv1654_shared"):
-                                                v0 = T.axis.spatial(1, 0)
-                                                v1 = T.axis.spatial(1, 0)
-                                                v2 = T.axis.spatial(4096, 
ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128 
+ ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3)
-                                                T.reads(lv1654[v0, v1, v2])
-                                                T.writes(lv1654_shared[v0, v1, 
v2])
-                                                lv1654_shared[v0, v1, v2] = 
lv1654[v0, v1, v2]
-                            for ax1_0_fused_0_1 in range(4):
-                                for ax0_ax1_ax2_fused_0 in range(1):
-                                    for ax0_ax1_ax2_fused_1 in T.vectorized(8):
-                                        with T.block("lv1654_shared_local"):
-                                            v0 = T.axis.spatial(1, 0)
-                                            v1 = T.axis.spatial(1, 0)
-                                            v2 = T.axis.spatial(4096, 
ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 + 
ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
-                                            T.reads(lv1654_shared[v0, v1, v2])
-                                            T.writes(lv1654_shared_local[v0, 
v1, v2])
-                                            lv1654_shared_local[v0, v1, v2] = 
lv1654_shared[v0, v1, v2]
-                                for ax1_1 in range(8):
-                                    with T.block("NT_matmul_rf_update"):
-                                        vax1_0_fused_1 = T.axis.spatial(32, 
ax1_0_fused_1)
-                                        v0 = T.axis.spatial(22016, ax0_fused_0 
* 8 + ax0_fused_1)
-                                        vax1_0_fused_0 = T.axis.reduce(16, 
ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1)
-                                        vax1_1 = T.axis.reduce(8, ax1_1)
-                                        
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], 
lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1], 
lv571[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv572[v0, 
(vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
-                                        
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-                                        
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + 
lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, (vax1_0_fused_0 * 256 
+ vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 + 
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * lv [...]
-                    for ax1_fused in range(1):
-                        for ax0 in T.thread_binding(32, thread="threadIdx.x"):
-                            with T.block("NT_matmul"):
-                                vax1_0_fused_1 = T.axis.reduce(32, ax0)
-                                v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + 
ax0_fused_1)
-                                
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-                                T.writes(var_NT_matmul_intermediate[0, 0, v0])
-                                with T.init():
-                                    var_NT_matmul_intermediate[0, 0, v0] = 
T.float16(0)
-                                var_NT_matmul_intermediate[0, 0, v0] = 
var_NT_matmul_intermediate[0, 0, v0] + 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]
+        for u_fused_ax0_fused_fused_0 in T.thread_binding(5504, 
thread="blockIdx.x"):
+            for u_fused_ax0_fused_fused_1 in T.thread_binding(4, 
thread="threadIdx.y"):
+                for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in 
T.thread_binding(64, thread="threadIdx.x"):
+                    for ax0, ax1 in T.grid(1, 1):
+                        for ax2_0 in T.serial(2, 
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+                            for ax2_1 in T.thread_binding(4, 
thread="threadIdx.y"):
+                                for ax2_2 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                    for ax2_3 in T.vectorized(8):
+                                        with T.block("lv1654_shared"):
+                                            v0, v1 = T.axis.remap("SS", [ax0, 
ax1])
+                                            v2 = T.axis.spatial(4096, ax2_0 * 
2048 + ax2_1 * 512 + ax2_2 * 8 + ax2_3)
+                                            T.reads(lv1654[v0, v1, v2])
+                                            T.writes(lv1654_shared[v0, v1, v2])
+                                            lv1654_shared[v0, v1, v2] = 
lv1654[v0, v1, v2]
+                    for u_fused_ax0_fused_fused_2_init in range(1):
+                        for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in 
T.vectorized(4):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = 
T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 
* 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+                                v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + 
u_fused_ax0_fused_fused_2_init)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
+                                
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = T.float16(0)
+                    for ax1_0_fused_ax1_1_fused_0 in T.serial(8, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax0_0, ax1 in T.grid(1, 1):
+                            for ax0_1 in T.vectorized(1):
+                                with T.block("lv571_local"):
+                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
+                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 64 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    T.reads(lv571[v0, v1])
+                                    T.writes(lv571_local[v0, v1])
+                                    lv571_local[v0, v1] = lv571[v0, v1]
+                        for u_fused_ax0_fused_fused_2, 
ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
+                            for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = 
T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 
* 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + 
u_fused_ax0_fused_fused_2)
+                                    vax1_0_fused_ax1_1_fused_0, 
vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, 
ax1_0_fused_ax1_1_fused_2])
+                                    
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0], lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + 
vax1_0_fused_ax1_1_fused_2 * 4 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], 
lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + va [...]
+                                    
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
+                                    
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] + lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + 
vax1_0_fused_ax1_1_fused_2 * 4 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * 
((T.Cast("float1 [...]
+            for ax2_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+                    for ax2_fused_1_0 in T.serial(1, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax2_fused_1_1 in T.vectorized(1):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = 
T.axis.spatial(64, ax0)
+                                v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                                
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] = T.float16(0)
+                            for ax1 in range(4):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = 
T.axis.remap("SR", [ax0, ax1])
+                                    v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                    
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0], 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
+                                    
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                                    
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] = 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] + 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
+            for ax1_fused_1 in range(1):
+                for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+                    for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+                        with T.block("NT_matmul"):
+                            
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = 
T.axis.reduce(64, ax0)
+                            v0 = T.axis.spatial(22016, 
u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1)
+                            
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                            T.writes(var_NT_matmul_intermediate[0, 0, v0])
+                            with T.init():
+                                var_NT_matmul_intermediate[0, 0, v0] = 
T.float16(0)
+                            var_NT_matmul_intermediate[0, 0, v0] = 
var_NT_matmul_intermediate[0, 0, v0] + 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0]
 
     # fmt: on
 
+    mod = tvm.IRModule({"main": before})
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
 
-class TestDecodeGEMV2(BaseBeforeAfter):
+
+def test_decode_gemv2():
     # fmt: off
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: 
T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), 
p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")):
         T.func_attr({"tir.noalias": T.bool(True)})
         # with T.block("root"):
@@ -267,73 +321,223 @@ class TestDecodeGEMV2(BaseBeforeAfter):
                 T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])
                 p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", 
var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: 
T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), 
p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")):
         T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
         # with T.block("root"):
         var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 32000), 
"float16", scope="local")
-        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1, 
32000), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((256, 1, 1, 
32000), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 1, 
32000), "float16", scope="local")
+        lv771_local = T.alloc_buffer((32000, 512), "uint32", scope="local")
         lv3216_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
-        lv3216_shared_local = T.alloc_buffer((1, 1, 4096), "float16", 
scope="local")
-        for u_fused in T.thread_binding(1, thread="blockIdx.y"):
-            for ax0_fused_0 in T.thread_binding(4000, thread="blockIdx.x"):
-                for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
-                    for ax1_0_fused_1 in T.thread_binding(32, 
thread="threadIdx.x"):
-                        with T.block("NT_matmul_rf_init"):
-                            vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
-                            v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + 
ax0_fused_1)
-                            T.reads()
-                            
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-                            
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
-                        for ax1_0_fused_0_0 in T.serial(4, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                            for ax0_ax1_ax2_fused_0 in range(1):
-                                for ax0_ax1_ax2_fused_1 in T.thread_binding(8, 
thread="threadIdx.y"):
-                                    for ax0_ax1_ax2_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
-                                        for ax0_ax1_ax2_fused_3 in 
T.vectorized(4):
-                                            with T.block("lv3216_shared"):
-                                                v0 = T.axis.spatial(1, 0)
-                                                v1 = T.axis.spatial(1, 0)
-                                                v2 = T.axis.spatial(4096, 
ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128 
+ ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3)
-                                                T.reads(lv3216[v0, v1, v2])
-                                                T.writes(lv3216_shared[v0, v1, 
v2])
-                                                lv3216_shared[v0, v1, v2] = 
lv3216[v0, v1, v2]
-                            for ax1_0_fused_0_1 in range(4):
-                                for ax0_ax1_ax2_fused_0 in range(1):
-                                    for ax0_ax1_ax2_fused_1 in T.vectorized(8):
-                                        with T.block("lv3216_shared_local"):
-                                            v0 = T.axis.spatial(1, 0)
-                                            v1 = T.axis.spatial(1, 0)
-                                            v2 = T.axis.spatial(4096, 
ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 + 
ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
-                                            T.reads(lv3216_shared[v0, v1, v2])
-                                            T.writes(lv3216_shared_local[v0, 
v1, v2])
-                                            lv3216_shared_local[v0, v1, v2] = 
lv3216_shared[v0, v1, v2]
-                                for ax1_1 in range(8):
-                                    with T.block("NT_matmul_rf_update"):
-                                        vax1_0_fused_1 = T.axis.spatial(32, 
ax1_0_fused_1)
-                                        v0 = T.axis.spatial(32000, ax0_fused_0 
* 8 + ax0_fused_1)
-                                        vax1_0_fused_0 = T.axis.reduce(16, 
ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1)
-                                        vax1_1 = T.axis.reduce(8, ax1_1)
-                                        
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], 
lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1], 
lv771[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv772[v0, 
(vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
-                                        
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-                                        
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + 
lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v0, (vax1_0_fused_0 * 256 
+ vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 + 
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * lv [...]
-                    for ax1_fused in range(1):
-                        for ax0 in T.thread_binding(32, thread="threadIdx.x"):
-                            with T.block("NT_matmul"):
-                                vax1_0_fused_1 = T.axis.reduce(32, ax0)
-                                v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + 
ax0_fused_1)
-                                
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-                                T.writes(var_NT_matmul_intermediate_local[0, 
0, v0])
-                                with T.init():
-                                    var_NT_matmul_intermediate_local[0, 0, v0] 
= T.float16(0)
-                                var_NT_matmul_intermediate_local[0, 0, v0] = 
var_NT_matmul_intermediate_local[0, 0, v0] + 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]
+        for u_fused_ax0_fused_fused_0 in T.thread_binding(8000, 
thread="blockIdx.x"):
+            for u_fused_ax0_fused_fused_1 in T.thread_binding(4, 
thread="threadIdx.y"):
+                for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in 
T.thread_binding(64, thread="threadIdx.x"):
+                    for ax0, ax1 in T.grid(1, 1):
+                        for ax2_0 in T.serial(2, 
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+                            for ax2_1 in T.thread_binding(4, 
thread="threadIdx.y"):
+                                for ax2_2 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                    for ax2_3 in T.vectorized(8):
+                                        with T.block("lv3216_shared"):
+                                            v0, v1 = T.axis.remap("SS", [ax0, 
ax1])
+                                            v2 = T.axis.spatial(4096, ax2_0 * 
2048 + ax2_1 * 512 + ax2_2 * 8 + ax2_3)
+                                            T.reads(lv3216[v0, v1, v2])
+                                            T.writes(lv3216_shared[v0, v1, v2])
+                                            lv3216_shared[v0, v1, v2] = 
lv3216[v0, v1, v2]
+                    for u_fused_ax0_fused_fused_2_init in range(1):
+                        for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in 
T.vectorized(4):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = 
T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 
* 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+                                v0 = T.axis.spatial(32000, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + 
u_fused_ax0_fused_fused_2_init)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
+                                
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = T.float16(0)
+                    for ax1_0_fused_ax1_1_fused_0 in T.serial(8, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax0_0, ax1 in T.grid(1, 1):
+                            for ax0_1 in T.vectorized(1):
+                                with T.block("lv771_local"):
+                                    v0 = T.axis.spatial(32000, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
+                                    v1 = T.axis.spatial(512, 
ax1_0_fused_ax1_1_fused_0 * 64 + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    T.reads(lv771[v0, v1])
+                                    T.writes(lv771_local[v0, v1])
+                                    lv771_local[v0, v1] = lv771[v0, v1]
+                        for u_fused_ax0_fused_fused_2, 
ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
+                            for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = 
T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 
* 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+                                    v0 = T.axis.spatial(32000, 
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + 
u_fused_ax0_fused_fused_2)
+                                    vax1_0_fused_ax1_1_fused_0, 
vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, 
ax1_0_fused_ax1_1_fused_2])
+                                    
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0], lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + 
vax1_0_fused_ax1_1_fused_2 * 4 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], 
lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + va [...]
+                                    
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0])
+                                    
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] = 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 0, 0, v0] + lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + 
vax1_0_fused_ax1_1_fused_2 * 4 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * 
((T.Cast("float1 [...]
+            for ax2_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+                for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+                    for ax2_fused_1_0 in T.serial(1, 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax2_fused_1_1 in T.vectorized(1):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = 
T.axis.spatial(64, ax0)
+                                v0 = T.axis.spatial(32000, 
u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                                
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] = T.float16(0)
+                            for ax1 in range(4):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = 
T.axis.remap("SR", [ax0, ax1])
+                                    v0 = T.axis.spatial(32000, 
u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                    
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0], 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
+                                    
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                                    
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] = 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0] + 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
+            for ax1_fused_1 in range(1):
+                for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+                    for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+                        with T.block("NT_matmul"):
+                            
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = 
T.axis.reduce(64, ax0)
+                            v0 = T.axis.spatial(32000, 
u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1)
+                            
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0])
+                            T.writes(var_NT_matmul_intermediate_local[0, 0, 
v0])
+                            with T.init():
+                                var_NT_matmul_intermediate_local[0, 0, v0] = 
T.float16(0)
+                            var_NT_matmul_intermediate_local[0, 0, v0] = 
var_NT_matmul_intermediate_local[0, 0, v0] + 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 0, 0, v0]
+            for ax0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+                for ax0_fused_1 in range(1):
                     with T.block("compute"):
-                        v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + 
ax0_fused_1)
+                        v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 
4 + ax0_fused_0 + ax0_fused_1)
                         T.reads(var_NT_matmul_intermediate_local[0, 0, v0])
                         T.writes(p_output0_intermediate[0, 0, v0])
                         p_output0_intermediate[0, 0, v0] = T.Cast("float32", 
var_NT_matmul_intermediate_local[0, 0, v0])
 
     # fmt: on
 
+    mod = tvm.IRModule({"main": before})
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+def test_decode_gemv3():
+    """GEMV rule regression: 4-bit decode + NT_matmul + bias-add epilogue.
+
+    [1,1,4096] = NT_matmul([1,1,11008], decode([4096,1376] u32 -> [4096,11008] f16)) + lv570.
+    """
+    # fmt: off
+
+    # "before": unscheduled TIR. "decode" unpacks eight 4-bit values from each
+    # uint32 (shift by (v_j % 8) * 4, mask with 15, subtract zero-point 7,
+    # scale by lv576); "NT_matmul" reduces over k = 11008; "T_add" adds lv570.
+    @T.prim_func(private=True)
+    def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), 
lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: 
T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: 
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), 
p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 
"float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), 
T.int64(11008)), "float16")
+        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(4096)), "float16")
+        for i, j in T.grid(T.int64(4096), T.int64(11008)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv575[v_i, v_j // T.int64(8)], lv576[v_i, v_j // 
T.int64(32)])
+                T.writes(p_output0_intermediate_1[v_i, v_j])
+                p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv575[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j 
% T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i, v_j 
// T.int64(32)]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), 
T.int64(11008)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, 
v_k])
+                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * 
p_output0_intermediate_1[v_i2, v_k]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(lv570[v_ax0, v_ax1, v_ax2], 
var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, 
v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+    # "expected": hand-written reference TIR the dlight GEMV rule should
+    # produce: lv574 staged into shared memory, lv575 cached in local scope,
+    # a 128-lane rfactor partial sum (rf_local) folded to 32 threadIdx.x
+    # lanes (rf_local_1), a cross-thread reduction, then the fused bias add.
+    @T.prim_func(private=True)
+    def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), 
lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: 
T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: 
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), 
p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 
"float16")):
+        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), 
T.int64(1), T.int64(4096)), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), 
T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), 
T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local")
+        lv575_local = T.alloc_buffer((T.int64(4096), T.int64(1376)), "uint32", 
scope="local")
+        lv574_shared = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(11008)), "float16", scope="shared")
+        # Launch: 256 blockIdx.x x (16 threadIdx.y x 32 threadIdx.x); each
+        # block covers 16 rows of the 4096-row output.
+        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(256), 
thread="blockIdx.x"):
+            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
+                for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in 
T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                        for ax2_0 in T.serial(T.int64(22), 
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+                            for ax2_1 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
+                                for ax2_2 in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                                    for ax2_3 in T.vectorized(T.int64(1)):
+                                        with T.block("lv574_shared"):
+                                            v0, v1 = T.axis.remap("SS", [ax0, 
ax1])
+                                            v2 = 
T.axis.spatial(T.int64(11008), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + 
ax2_2 + ax2_3)
+                                            T.where((ax2_0 * T.int64(16) + 
ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(11008))
+                                            T.reads(lv574[v0, v1, v2])
+                                            T.writes(lv574_shared[v0, v1, v2])
+                                            lv574_shared[v0, v1, v2] = 
lv574[v0, v1, v2]
+                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
+                        for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in 
T.vectorized(T.int64(4)):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = 
T.axis.spatial(T.int64(128), 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+                                v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + 
u_fused_ax0_fused_fused_2_init)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 T.int64(0), T.int64(0), v0])
+                                
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 T.int64(0), T.int64(0), v0] = T.float16(0)
+                    for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax0_0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                            for ax0_1 in T.vectorized(T.int64(1)):
+                                with T.block("lv575_local"):
+                                    v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + ax0_0 + 
ax0_1)
+                                    v1 = T.axis.spatial(T.int64(1376), 
ax1_0_fused_ax1_1_fused_0 * T.int64(32) + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    T.reads(lv575[v0, v1])
+                                    T.writes(lv575_local[v0, v1])
+                                    lv575_local[v0, v1] = lv575[v0, v1]
+                        for u_fused_ax0_fused_fused_2, 
ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)):
+                            for 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in 
T.vectorized(T.int64(4)):
+                                # NOTE(review): the T.reads(...) and the
+                                # update statement in this block were
+                                # truncated by the mail archive ("[...]");
+                                # restore the full decode-and-multiply
+                                # expressions from the original commit.
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = 
T.axis.spatial(T.int64(128), 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + 
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+                                    v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + 
u_fused_ax0_fused_fused_2)
+                                    vax1_0_fused_ax1_1_fused_0, 
vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, 
ax1_0_fused_ax1_1_fused_2])
+                                    
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 T.int64(0), T.int64(0), v0], lv574_shared[T.int64(0), T.int64(0), 
vax1_0_fused_ax1_1_fused_0 * T.int64(256) + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * 
T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)], 
lv575_local[v0, vax1_0_fused_ax1_1_fus [...]
+                                    
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 T.int64(0), T.int64(0), v0])
+                                    
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 T.int64(0), T.int64(0), v0] = 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
 T.int64(0), T.int64(0), v0] + lv574_shared[T.int64(0), T.int64(0), 
vax1_0_fused_ax1_1_fused_0 * T.int64(256) + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * 
T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int6 [...]
+            # Fold the 128 rfactor lanes into 32 (4 partials per
+            # threadIdx.x lane).
+            for ax2_fused_0 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
+                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for ax2_fused_1_0 in T.serial(T.int64(1), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax2_fused_1_1 in T.vectorized(T.int64(1)):
+                            with T.block("NT_matmul_rf_init"):
+                                
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = 
T.axis.spatial(T.int64(32), ax0)
+                                v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0 + ax2_fused_1_0 + 
ax2_fused_1_1)
+                                T.reads()
+                                
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 T.int64(0), T.int64(0), v0])
+                                
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 T.int64(0), T.int64(0), v0] = T.float16(0)
+                            for ax1 in range(T.int64(4)):
+                                with T.block("NT_matmul_rf_update"):
+                                    
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = 
T.axis.remap("SR", [ax0, ax1])
+                                    v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0 + ax2_fused_1_0 + 
ax2_fused_1_1)
+                                    
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 T.int64(0), T.int64(0), v0], 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 
T.int64(0), T.int64(0), v0])
+                                    
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 T.int64(0), T.int64(0), v0])
+                                    
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 T.int64(0), T.int64(0), v0] = 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 T.int64(0), T.int64(0), v0] + 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 
T.int64(0), T.int64(0), v0]
+            # Final cross-thread (threadIdx.x) reduction into the local
+            # accumulator.
+            for ax1_fused_1 in range(T.int64(1)):
+                for ax1_fused_0 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
+                    for ax0 in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                        with T.block("NT_matmul"):
+                            
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = 
T.axis.reduce(T.int64(32), ax0)
+                            v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0 + ax1_fused_1)
+                            
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 T.int64(0), T.int64(0), v0])
+                            
T.writes(var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
+                            with T.init():
+                                var_NT_matmul_intermediate_local[T.int64(0), 
T.int64(0), v0] = T.float16(0)
+                            var_NT_matmul_intermediate_local[T.int64(0), 
T.int64(0), v0] = var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] 
+ 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
 T.int64(0), T.int64(0), v0]
+            for ax0_fused_0 in T.thread_binding(T.int64(16), 
thread="threadIdx.y"):
+                for ax0_fused_1 in range(T.int64(1)):
+                    with T.block("T_add"):
+                        v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0 + ax0_fused_1)
+                        T.reads(lv570[T.int64(0), T.int64(0), v0], 
var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
+                        T.writes(p_output0_intermediate[T.int64(0), 
T.int64(0), v0])
+                        p_output0_intermediate[T.int64(0), T.int64(0), v0] = 
lv570[T.int64(0), T.int64(0), v0] + 
var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]
+
+    # fmt: on
+
+    mod = tvm.IRModule({"main": before})
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+    mod.show(black_format=False)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
 
 # Entry point: run all tests in this file through TVM's pytest-based runner.
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to