This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ff0b99c5ce [Dlight] Scheduling Low batch GEMM using GEMV-like rule (#16579)
ff0b99c5ce is described below

commit ff0b99c5ce4371ec966cd4fa07ae36351faf2a5e
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Feb 21 11:31:54 2024 -0500

    [Dlight] Scheduling Low batch GEMM using GEMV-like rule (#16579)
    
    * low batch
    
    * fix
    
    * fix lint
    
    * do dequantize only once
    
    * change default
    
    * add test
    
    * fix lint
    
    * fix lint
---
 python/tvm/dlight/gpu/__init__.py              |   1 +
 python/tvm/dlight/gpu/low_batch_gemv.py        | 605 +++++++++++++++++++++++++
 src/driver/driver_api.cc                       |   9 +-
 src/tir/transforms/hoist_expression.cc         |   9 +-
 tests/python/dlight/test_gpu_low_batch_gemv.py | 255 +++++++++++
 5 files changed, 876 insertions(+), 3 deletions(-)

diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py
index 7db383a161..077fdcaeb0 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -19,6 +19,7 @@ GPU-generic schedule rules.
 For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead
 """
 from .gemv import GEMV
+from .low_batch_gemv import LowBatchGEMV
 from .fallback import Fallback
 from .matmul import Matmul
 from .reduction import Reduction
diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py
new file mode 100644
index 0000000000..dfed020853
--- /dev/null
+++ b/python/tvm/dlight/gpu/low_batch_gemv.py
@@ -0,0 +1,605 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A rule for low-batch GEMM / decode-GEMM using GEMV schedule."""
+import re
+from functools import reduce
+from typing import List, Optional, Union, Set
+
+from tvm import DataType, arith, ir, tir
+from tvm.target import Target
+
+from ..base import (
+    BlockInfo,
+    collect_block_iter_vars_used_in_access_region,
+    collect_vars_used_in_prim_expr,
+    is_broadcast_epilogue,
+    normalize_prim_func,
+    try_inline_contiguous_spatial,
+)
+from .base import GPUScheduleRule
+
+
+def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
+    # Detect and return `Y` in `X[...] = X[...] + Y`
+    buffer_store = block.body
+    if not isinstance(buffer_store, tir.BufferStore):
+        return None
+    if not isinstance(buffer_store.value, tir.Add):
+        return None
+    if not ir.structural_equal(
+        buffer_store.value.a,
+        tir.BufferLoad(buffer_store.buffer, block.body.indices),
+        map_free_vars=True,
+    ):
+        return None
+    return buffer_store.value.b
+
+
+def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
+    loop: tir.For = sch.get(loop_rv)
+    return loop.extent.value if isinstance(loop.extent, tir.IntImm) else 
loop.extent
+
+
+def get_bytes(dtype: Union[DataType, str]) -> int:
+    num = re.findall(r"\d+", dtype)
+    if len(num) != 1:
+        raise ValueError(f"Cannot get bytes from {dtype}")
+    return int(num[0]) // 8
+
+
+def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> 
Optional[List[tir.Buffer]]:
+    """Check if the block is a low batch GEMM.
+
+    Parameters
+    ----------
+
+    sch : tir.Schedule
+        The schedule
+
+    block_info : BlockInfo
+        The block info to be checked
+
+
+    Returns
+    -------
+    ret : Optional[List[tir.Buffer]]
+        The vector-like buffers used in the low batch GEMM if it is a low 
batch GEMM,
+        otherwise None.
+    """
+    block = block_info.block_rv
+    block_stmt = sch.get(block)
+    conditions = []
+    conditions.append(block_info.is_reduction())
+    conditions.append(len(block_stmt.reads) >= 2)
+    conditions.append(len(block_stmt.writes) == 1)
+    conditions.append(_get_reduction_expr(block_stmt) is not None)
+    conditions.append(
+        len(collect_block_iter_vars_used_in_access_region(block_stmt, 
block_stmt.writes[0].region))
+        > 0
+    )
+    if not all(conditions):
+        return None
+    const_iter_vars = set(
+        iter_var.var
+        for iter_var in block_stmt.iter_vars
+        if isinstance(iter_var.dom.extent, tir.IntImm)
+    )
+    if len(const_iter_vars) == len(block_stmt.iter_vars):
+        return None
+    ret = [
+        read.buffer
+        for read in block_stmt.reads
+        if len(
+            collect_block_iter_vars_used_in_access_region(block_stmt, 
read.region) & const_iter_vars
+        )
+        < len(const_iter_vars)
+        and len(
+            collect_block_iter_vars_used_in_access_region(block_stmt, 
read.region) & const_iter_vars
+        )
+        > 0
+    ]
+    return ret if 0 < len(ret) < len(block_stmt.reads) else None
+
+
+def detect_dominant_read(block: tir.Block, const_iter_vars: Set[tir.Var]) -> 
tir.PrimExpr:
+    """Detect the dominant read indices in the block."""
+    dominant_read = None
+    num_read_iters = -1
+    for buffer_region in block.reads:
+        tir_vars = (
+            collect_block_iter_vars_used_in_access_region(block, 
buffer_region.region)
+            & const_iter_vars
+        )
+        if num_read_iters < len(tir_vars):
+            num_read_iters = len(tir_vars)
+            dominant_read = buffer_region
+    assert dominant_read is not None
+    (result,) = dominant_read.buffer.offset_of([e.min for e in 
dominant_read.region])
+    return result
+
+
+def normalize(
+    sch: tir.Schedule,
+    block_info: BlockInfo,
+) -> Optional[bool]:
+    """Normalize the main block."""
+    block_stmt: tir.Block = sch.get(block_info.block_rv)
+    const_iter_vars = set(
+        iter_var.var
+        for iter_var in block_stmt.iter_vars
+        if isinstance(iter_var.dom.extent, tir.IntImm)
+    )
+    dynamic_iter_vars = set(
+        iter_var.var for iter_var in block_stmt.iter_vars if iter_var.var not 
in const_iter_vars
+    )
+    access = arith.normalize_to_iter_sum(
+        detect_dominant_read(block_stmt, const_iter_vars),
+        input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+    )
+    buffers_use_vars = [
+        collect_block_iter_vars_used_in_access_region(block_stmt, buf.region)
+        for buf in block_stmt.writes
+    ]
+    buffers_use_vars.extend(
+        [
+            collect_block_iter_vars_used_in_access_region(block_stmt, 
buf.region)
+            for buf in block_stmt.reads
+        ]
+    )
+    if collect_vars_used_in_prim_expr(access.base) & set(
+        iter_var.var for iter_var in block_stmt.iter_vars
+    ):
+        return None
+    iter_to_info = {i.var: i for i in block_info.iters}
+    batch_loops, s_loops, r_loops, c_loops = [], [], [], []
+    inner_axis = access.args[-1].source.source
+    is_inner_reduction = iter_to_info[inner_axis].kind == "R"
+
+    for split_expr in access.args:
+        var = split_expr.source.source
+        info = iter_to_info.get(var)
+        loop = info.loop_rv
+        is_reduction = info.kind == "R"
+        if split_expr.lower_factor > 1:
+            if c_loops:
+                return None
+            loop, c_loop = sch.split(loop, factors=[None, 
split_expr.lower_factor])
+            # we only support the reduction dim being grouped atm
+            if not is_reduction:
+                return None
+            c_loops.append(c_loop)
+        if is_reduction:
+            r_loops.append(loop)
+        elif all([var in buf_vars for buf_vars in buffers_use_vars]):
+            batch_loops.append(loop)
+        else:
+            s_loops.append(loop)
+
+    assert s_loops
+    assert r_loops
+    if not c_loops:
+        c_loops = [sch.add_unit_loop(block_info.block_rv)]
+    dynamic_loops = [iter_to_info[var].loop_rv for var in dynamic_iter_vars]
+    assert len(dynamic_loops) == 1
+    if not batch_loops:
+        batch_loops = [sch.add_unit_loop(block_info.block_rv)]
+    sch.reorder(*dynamic_loops, *batch_loops, *s_loops, *r_loops, *c_loops)
+    sch.fuse(*batch_loops)
+    sch.fuse(*s_loops)
+    sch.fuse(*r_loops)
+    return is_inner_reduction
+
+
+class LowBatchGEMV(GPUScheduleRule):
+    """A rule for low batch GEMM / decode-GEMM."""
+
+    def __init__(self, bucket=4):
+        self.bucket = bucket
+
+    def apply(  # pylint: 
disable=too-many-locals,too-many-branches,too-many-return-statements
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+        if not isinstance(func, tir.PrimFunc) or not 
self.is_target_available(target):
+            return None
+        sch = tir.Schedule(func)
+        block_infos = normalize_prim_func(sch)
+
+        reduction_block_infos = [
+            block_info for block_info in block_infos if 
block_info.is_reduction()
+        ]
+        if len(reduction_block_infos) != 1:
+            return None
+        reduction_block_info = reduction_block_infos[0]
+        vector_input_buffers = is_gemv(sch, reduction_block_info)
+        if vector_input_buffers is None:
+            return None
+        batch_pad = self.bucket
+        pad_value = [
+            iter.dom if isinstance(iter.dom, int) else batch_pad
+            for iter in reduction_block_info.iters
+        ]
+        sch.pad_einsum(reduction_block_info.block_rv, pad_value)
+        block_infos = normalize_prim_func(sch)
+        dequantize_block = None
+        pad_input_block = None
+        for block_info in block_infos:
+            if "dequantize" in block_info.name:
+                dequantize_block = block_info.block_rv
+            elif "pad" in block_info.name and 
len(sch.get_producers(block_info.block_rv)) == 0:
+                pad_input_block = block_info.block_rv
+        block_infos = [
+            block_info
+            for block_info in block_infos
+            if "pad" not in block_info.name and "dequantize" not in 
block_info.name
+        ]
+        block_infos = try_inline_contiguous_spatial(sch, block_infos)
+        if len(block_infos) == 1:
+            epilogue = None
+        elif len(block_infos) == 2:
+            epilogue = block_infos[1]
+            if not epilogue.is_injective():
+                return None
+        else:
+            return None
+
+        block_info = block_infos[0]
+        if len(block_info.iters) not in [2, 3]:
+            # either [B, S, R] = [B, S, R] * [B, R]
+            # or [S, R] = [S, R] * [R]
+            return None
+        block = block_info.block_rv
+        vector_input_buffers = is_gemv(sch, block_info)
+        if vector_input_buffers is None:
+            return None
+
+        # Step 1. Normalize the block, merge spatial and reduction iters
+        is_inner_reduction = normalize(sch, block_info)
+        # Step 2. Do the scheduling
+        if is_inner_reduction is None:
+            return None
+        elif is_inner_reduction:
+            self.sch_inner_reduction(
+                sch,
+                target,
+                block,
+                dequantize_block,
+                pad_input_block,
+                vector_input_buffers,
+                epilogue,
+                batch_pad,
+            )
+            return sch
+        else:
+            raise NotImplementedError("Outer reduction is not supported yet")
+
+    def sch_inner_reduction(  # pylint: disable=too-many-arguments, 
invalid-name, unused-argument
+        self,
+        sch: tir.Schedule,
+        target: Target,
+        block: tir.schedule.BlockRV,
+        dequantize_block: Optional[tir.schedule.BlockRV],
+        pad_input_block: Optional[tir.schedule.BlockRV],
+        vector_input_buffers: List[tir.Buffer],
+        epilogue_info: Optional[BlockInfo],
+        batch_pad: int,
+    ):
+        """Schedule the inner reduction block."""
+
+        def get_max_factor(n, factors):
+            factors = sorted(factors, reverse=True)
+            for factor in factors:
+                if n % factor == 0:
+                    return factor
+            return 1
+
+        def apply(
+            sch: tir.Schedule,
+            gemv,
+            TAG_S,
+            TAG_R,
+            TS,
+            TR,
+            TILE_S,
+            TILE_R,
+            VEC_LOAD,
+            VEC_C,
+            LOAD_V_SHARED,
+            LOAD_V_VEC,
+            UNROLL,
+        ):
+            # rfactor: reduce to tx * vec_c
+
+            _, b, s, r, c = sch.get_loops(block=gemv)
+            s = sch.fuse(b, s)
+            r = sch.fuse(r, c)
+            bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], 
preserve_unit_iters=True)
+            r, tr, tile_r_vec_n, vec_c = sch.split(
+                r, factors=[None, TR, TILE_R // VEC_C, VEC_C], 
preserve_unit_iters=True
+            )
+            sch.reorder(r, tile_r_vec_n, tr, vec_c)
+            tr_vec_c = sch.fuse(tr, vec_c)
+            rf = sch.rfactor(tr_vec_c, 0)
+
+            # rfactor: reduce to tx
+            _, bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], 
preserve_unit_iters=True)
+            rf2 = sch.rfactor(tr, 0)
+            # bind, vectorize compute
+            batch_loop, bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = 
sch.get_loops(block=rf)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], 
preserve_unit_iters=True)
+            sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c)
+            sch.bind(bx, "blockIdx.x")
+            sch.bind(ts, TAG_S)
+            sch.bind(tr, TAG_R)
+            sch.vectorize(vec_c)
+            by, batch = sch.split(batch_loop, factors=[None, batch_pad])
+            sch.bind(by, "blockIdx.y")
+            sch.reorder(bx, ts, tr, r, batch)
+
+            shared_mem_usage = 0
+            for buf in vector_input_buffers:
+                buf_size = reduce(
+                    lambda x, y: x * y, buf.shape, 
tir.IntImm(buf.shape[0].dtype, 1)
+                ) * get_bytes(buf.dtype)
+                shared_mem_usage += buf_size
+            LOAD_V_SHARED = (
+                LOAD_V_SHARED
+                and isinstance(shared_mem_usage, tir.IntImm)
+                and shared_mem_usage.value <= 
target.max_shared_memory_per_block
+            )
+
+            # vectorize load A
+            # (TODO) this is now actually problematic since the number of 
loops is dependent on the
+            # number of dimensions of A_q
+            if dequantize_block is not None:
+                sch.compute_at(dequantize_block, r, preserve_unit_loops=True)
+                sch.set_scope(dequantize_block, 0, "local")
+
+                s_local, r_local = sch.get_loops(block=dequantize_block)[-2:]
+                s_local, vec_load = sch.split(
+                    s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
+                )
+                sch.reorder(s_local, r_local, vec_load)  # either s_local or 
r_local should be 1
+                sch.vectorize(vec_load)
+
+            # load vector into shared memory, shape should be the whole vector
+            if LOAD_V_SHARED:
+                assert len(vector_input_buffers) == 1
+                V_shared = sch.cache_read(rf, read_buffer_index=0, 
storage_scope="shared")
+                sch.compute_at(V_shared, tr, preserve_unit_loops=True)
+                l = sch.get_loops(block=V_shared)[-1]
+                loop: tir.For = sch.get(l)
+                if isinstance(loop.extent, tir.IntImm):
+                    # avoid introducing predicates when vector length is too 
large
+                    vec_length = max(
+                        min(
+                            get_max_factor(
+                                (int)(loop.extent),
+                                [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * 
TR * 8],
+                            )
+                            // TS
+                            // TR,
+                            LOAD_V_VEC,
+                        ),
+                        1,
+                    )
+                else:
+                    vec_length = LOAD_V_VEC
+                if TAG_R == "threadIdx.x":
+                    _, ty, tx, vec = sch.split(
+                        l, factors=[None, TS, TR, vec_length], 
preserve_unit_iters=True
+                    )
+                else:
+                    _, ty, tx, vec = sch.split(
+                        l, factors=[None, TR, TS, vec_length], 
preserve_unit_iters=True
+                    )
+                sch.bind(ty, "threadIdx.y")
+                sch.bind(tx, "threadIdx.x")
+                sch.vectorize(vec)
+            if pad_input_block is not None:
+                sch.compute_inline(pad_input_block)
+
+            # reduce tile_s * tr * vec to tile_s * tr
+            sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+            tr, vec_c, batch_loop, *ts_tile_s = sch.get_loops(block=rf2)[2:]
+            ts_tile_s = sch.fuse(*ts_tile_s)
+            ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], 
preserve_unit_iters=True)
+            tile_s, vec_s = sch.split(
+                tile_s,
+                factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])],
+                preserve_unit_iters=True,
+            )
+            sch.reorder(ts, tr, tile_s, batch_loop, vec_s, vec_c)
+            sch.bind(ts, TAG_S)
+            sch.bind(tr, TAG_R)
+            sch.vectorize(vec_s)
+
+            # reduce tile_s * tr to tile_s
+            sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+
+            tr, batch_loop, *ts_tile_s = sch.get_loops(block=gemv)[2:]
+            ts_tile_s = sch.fuse(*ts_tile_s)
+            ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], 
preserve_unit_iters=True)
+            sch.reorder(tile_s, batch_loop, ts, tr)
+            sch.bind(ts, TAG_S)
+            sch.bind(tr, TAG_R)
+
+            sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[4])
+            sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+            sch.set_scope(rf, buffer_index=0, storage_scope="local")
+            sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+            unroll_factor = UNROLL
+
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf)[4],
+                ann_key="pragma_auto_unroll_max_step",
+                ann_val=unroll_factor,
+            )
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf)[4], 
ann_key="pragma_unroll_explicit", ann_val=1
+            )
+
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[4],
+                ann_key="pragma_auto_unroll_max_step",
+                ann_val=unroll_factor,
+            )
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[4], 
ann_key="pragma_unroll_explicit", ann_val=1
+            )
+
+            if LOAD_V_SHARED:
+                sch.annotate(
+                    block_or_loop=sch.get_loops(V_shared)[-4],
+                    ann_key="pragma_unroll_explicit",
+                    ann_val=unroll_factor,
+                )
+                sch.annotate(
+                    block_or_loop=sch.get_loops(V_shared)[-4], 
ann_key="pragma_vectorize", ann_val=1
+                )
+
+            epilogue = sch.get_consumers(gemv)
+            # Schedule epilogue
+            if epilogue:
+                epilogue = epilogue[0]
+                if is_broadcast_epilogue(sch, block, epilogue):
+                    sch.reverse_compute_at(epilogue, bx)
+                    sch.set_scope(block, 0, "shared")
+                    _, _, _, *s = sch.get_loops(epilogue)  # pylint: 
disable=invalid-name
+                    _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
+                    sch.bind(tx, "threadIdx.x")
+                else:
+                    sch.reverse_compute_at(epilogue, bx, 
preserve_unit_loops=True)
+                    ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[3:])
+                    ts_tile_s = sch.get_loops(epilogue)[-1]
+                    ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], 
preserve_unit_iters=True)
+                    sch.bind(ts, TAG_S)
+                    sch.set_scope(block, 0, "local")
+
+            return sch
+
+        # Specify the `len_tx` and `len_ty` according to the loop extent
+        _, batch, s, r, c = sch.get_loops(block=block)
+        len_batch, len_s, len_r, len_c = (
+            get_extent(sch, batch),
+            get_extent(sch, s),
+            get_extent(sch, r),
+            get_extent(sch, c),
+        )
+        len_S = len_batch * len_s
+        len_R = len_r * len_c
+
+        TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
+        if target.kind.name == "cuda":
+            VEC_C = 4
+            LOAD_V_SHARED = True
+            LOAD_V_VEC = 8
+            UNROLL = 256
+            if isinstance(len_S, int):
+                if len_S > len_R:
+                    TS, TR = 4, 64
+                else:
+                    TS, TR = 16, 32
+        elif target.kind.name == "metal":
+            # Note that the following tile size is tuned on M2 Ultra for 7B
+            TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
+            VEC_C = 1
+            LOAD_V_SHARED = False
+            LOAD_V_VEC = -1
+            UNROLL = 256
+            if isinstance(len_S, int):
+                if len_S > len_R:
+                    TS, TR = 2, 32
+                else:
+                    TS, TR = 2, 64
+        elif target.kind.name == "rocm":
+            VEC_C = 4
+            LOAD_V_SHARED = True
+            LOAD_V_VEC = 8
+            UNROLL = 256
+            if isinstance(len_S, int):
+                if len_S > len_R:
+                    TS, TR = 1, 128
+                else:
+                    TS, TR = 8, 64
+        elif target.kind.name == "opencl" and "android" in str(target.host):
+            TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
+            VEC_C = 8
+            LOAD_V_SHARED = False
+            LOAD_V_VEC = -1
+            UNROLL = 8
+            TS, TR = 2, 32
+        elif target.kind.name == "vulkan":
+            VEC_C = 4
+            LOAD_V_SHARED = True
+            LOAD_V_VEC = 4
+            UNROLL = 256
+            if isinstance(len_S, int):
+                if len_S > len_R:
+                    TS, TR = 4, 32
+                else:
+                    TS, TR = 16, 32
+        elif target.kind.name == "opencl" and "mali" in str(target.attrs):
+            VEC_C = 8
+            LOAD_V_SHARED = False
+            LOAD_V_VEC = -1
+            UNROLL = 64
+            TS, TR = 1, 64
+        else:
+            VEC_C = 1
+            LOAD_V_SHARED = False
+            LOAD_V_VEC = -1
+            UNROLL = 64
+            TS, TR = 1, 64
+
+        if not isinstance(len_S, int):
+            TS, TR = 1, 64
+
+        while TS * TR > target.max_num_threads:
+            if TS > 1:
+                TS //= 2
+            else:
+                TR //= 2
+
+        TILE_S, TILE_R = (
+            2,
+            len_c
+            if len_c > 1
+            else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) 
// TR, 1),
+        )
+        VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)
+        VEC_LOAD = 1
+        return apply(
+            sch,
+            gemv=block,
+            TAG_S=TAG_S,
+            TAG_R=TAG_R,
+            TS=TS,
+            TR=TR,
+            TILE_S=TILE_S,
+            TILE_R=TILE_R,
+            VEC_LOAD=VEC_LOAD,
+            VEC_C=VEC_C,
+            LOAD_V_SHARED=LOAD_V_SHARED,
+            LOAD_V_VEC=LOAD_V_VEC,
+            UNROLL=UNROLL,
+        )
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 4eca8aebd7..bdadb6db0f 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -240,6 +240,10 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   if (use_async_copy) {
     pass_list.push_back(tir::transform::LowerAsyncDMA());
   }
+  // HoistIfThenElse must be applied before UnrollLoop
+  // because HoistIfThenElse could utilize for loop structure
+  // which might be unrolled in UnrollLoop
+  pass_list.push_back(tir::transform::HoistIfThenElse());
   pass_list.push_back(tir::transform::UnrollLoop());
 
   // Add user-defined phase-2 passes
@@ -250,7 +254,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
-  pass_list.push_back(tir::transform::HoistIfThenElse());
 
   // Add user-defined phase-3 passes
   pass_list.insert(pass_list.end(), user_lower_phase3.begin(), 
user_lower_phase3.end());
@@ -586,7 +589,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
 
   mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
   mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
-  mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
   mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
   mixed_pass_list.push_back(tir::transform::InferFragment());
   mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
@@ -604,6 +606,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
 
   mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
   mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+  // MergeSharedMemoryAllocations must be applied after SplitHostDevice
+  // because the merged allocation site is at the beginning of each device function
+  mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
 
   bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
                           .value_or(relay::Executor::Create("graph", {}))
diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc
index 494fd7184f..f0fc90ee32 100644
--- a/src/tir/transforms/hoist_expression.cc
+++ b/src/tir/transforms/hoist_expression.cc
@@ -558,7 +558,14 @@ Pass HoistIfThenElse() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
     auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
-
+    auto flag = f->GetAttr<Integer>("tir.HoistIfThenElseExprWithBlock");
+    if (flag && flag.value().IntValue() == 1) {
+      HoistExpressionConfig 
config(static_cast<int>(HoistedConditionals::kUsingBlockVar) |
+                                       
static_cast<int>(HoistedConditionals::kIfElseExpr),
+                                   
static_cast<int>(HoistedLetBindings::kNone));
+      n->body = ExpressionHoister::Hoist(std::move(n->body), config);
+      return f;
+    }
     if (!cfg.defined()) {
       cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
     }
diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py
new file mode 100644
index 0000000000..5827b7b810
--- /dev/null
+++ b/tests/python/dlight/test_gpu_low_batch_gemv.py
@@ -0,0 +1,255 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def test_batch_decode_gemv():
+    # fmt: off
+
+    @T.prim_func(private=True)
+    def before(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), 
lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, 
p_output0: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True), 
"tir.HoistIfThenElseExprWithBlock": 1})
+        batch_size = T.int64()
+        lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), 
T.int64(28672)), "float16")
+        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, 
T.int64(1), T.int64(4096)), "float16")
+        # with T.block("root"):
+        compute = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16")
+        dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), 
T.int64(28672)), "float16")
+        for i0, i1 in T.grid(T.int64(4096), T.int64(28672)):
+            with T.block("compute"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                T.reads(lv429[v_i0, v_i1 // T.int64(8)])
+                T.writes(compute[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.Cast("float16", 
T.bitwise_and(T.shift_right(lv429[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", 
v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
+        for i0, i1 in T.grid(T.int64(4096), T.int64(28672)):
+            with T.block("dequantize"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                T.reads(compute[v_i0, v_i1], lv430[v_i0, v_i1 // T.int64(32)])
+                T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
+                dequantize_intermediate_intermediate[v_i0, v_i1] = 
(compute[v_i0, v_i1] - T.float16(7)) * lv430[v_i0, v_i1 // T.int64(32)]
+        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(4096), 
T.int64(28672)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv807[v_i0, v_i1, v_k], 
dequantize_intermediate_intermediate[v_i2, v_k])
+                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                NT_matmul_intermediate[v_i0, v_i1, v_i2] = 
NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv807[v_i0, v_i1, v_k] * 
dequantize_intermediate_intermediate[v_i2, v_k]
+
+    @T.prim_func(private=True)
+    def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), 
lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, 
p_output0: T.handle):
+        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, 
"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        batch_size = T.int64()
+        lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), 
T.int64(28672)), "float16")
+        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, 
T.int64(1), T.int64(4096)), "float16")
+        # with T.block("root"):
+        dequantize_intermediate_intermediate_local = 
T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16", scope="local")
+        NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", 
scope="local")
+        NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(64), 
(batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), 
T.int64(4096)), "float16", scope="local")
+        NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(64), 
(batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), 
T.int64(4096)), "float16", scope="local")
+        for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), 
thread="blockIdx.y"):
+            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1024), 
thread="blockIdx.x"):
+                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(2), 
thread="threadIdx.x"):
+                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in 
T.thread_binding(T.int64(64), thread="threadIdx.y"):
+                        for ax0_1_init, u_fused_ax1_fused_fused_2_init in 
T.grid(T.int64(4), T.int64(2)):
+                            for 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in 
T.vectorized(T.int64(1)):
+                                with T.block("NT_matmul_rf_init"):
+                                    
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                    v0 = T.axis.spatial((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
+                                    v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) 
+ u_fused_ax1_fused_fused_2_init)
+                                    T.reads()
+                                    
T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 v0, T.int64(0), v1])
+                                    
NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 v0, T.int64(0), v1] = T.float16(0)
+                        for ax2_fused_u_fused_0 in T.serial(T.int64(56), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
+                                for ax0_1 in T.vectorized(T.int64(1)):
+                                    with T.block("dequantize"):
+                                        v0 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) 
+ ax0_0_1 + ax0_1)
+                                        v1 = T.axis.spatial(T.int64(28672), 
ax2_fused_u_fused_0 * T.int64(512) + 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(8) + ax1)
+                                        T.reads(lv429[v0, v1 // T.int64(8)], 
lv430[v0, v1 // T.int64(32)])
+                                        
T.writes(dequantize_intermediate_intermediate_local[v0, v1])
+                                        
dequantize_intermediate_intermediate_local[v0, v1] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv429[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % 
T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv430[v0, v1 // 
T.int64(32)]
+                            for ax0_1, u_fused_ax1_fused_fused_2, 
ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(8)):
+                                for 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(1)):
+                                    with T.block("NT_matmul_rf_update"):
+                                        
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                        v0 = T.axis.spatial((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
+                                        v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) 
+ u_fused_ax1_fused_fused_2)
+                                        vax2_fused_u_fused_0, 
vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, 
ax2_fused_u_fused_2])
+                                        
T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 v0, T.int64(0), v1], lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) 
+ vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + 
vax2_fused_u_fused_2], dequantize_intermediate_intermediate_local[v1, 
vax2_fused_u_fused_0 * T.int64(512) + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + 
vax2_fused_u_fused_2])
+                                        
T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 v0, T.int64(0), v1])
+                                        
NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 v0, T.int64(0), v1] = 
NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
 v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, lv807[v0, T.int64(0), 
vax2_fused_u_fused_0 * T.int64(512) + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + 
vax2_fused_u_fused_2], T.float16(0)) * 
dequantize_intermediate_intermediate_local[v1,  [...]
+                for ax3_fused_0 in T.thread_binding(T.int64(2), 
thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(T.int64(64), 
thread="threadIdx.y"):
+                        for ax3_fused_1_0 in T.serial(T.int64(1), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax2 in range(T.int64(4)):
+                                for ax3_fused_1_1 in T.vectorized(T.int64(2)):
+                                    with T.block("NT_matmul_rf_init"):
+                                        
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), 
ax0)
+                                        v0 = T.axis.spatial((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                        v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + 
ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                        T.reads()
+                                        
T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1])
+                                        
NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1] = T.float16(0)
+                                    for ax1 in range(T.int64(1)):
+                                        with T.block("NT_matmul_rf_update"):
+                                            
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, 
ax1])
+                                            v0 = T.axis.spatial((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                            v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + 
ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                            
T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1], 
NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0
 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1])
+                                            
T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1])
+                                            
NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1] = 
NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1] + 
NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0
 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]
+                for ax2_fused_1, ax1 in T.grid(T.int64(2), T.int64(4)):
+                    for ax2_fused_0 in T.thread_binding(T.int64(2), 
thread="threadIdx.x"):
+                        for ax0 in T.thread_binding(T.int64(64), 
thread="threadIdx.y"):
+                            with T.block("NT_matmul"):
+                                
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), 
ax0)
+                                v0 = T.axis.spatial((batch_size + T.int64(3)) 
// T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
+                                v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + ax2_fused_0 * T.int64(2) + ax2_fused_1)
+                                
T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1])
+                                T.writes(NT_matmul_intermediate_pad_local[v0, 
T.int64(0), v1])
+                                with T.init():
+                                    NT_matmul_intermediate_pad_local[v0, 
T.int64(0), v1] = T.float16(0)
+                                NT_matmul_intermediate_pad_local[v0, 
T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + 
NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1]
+                for ax0 in range(T.int64(4)):
+                    for ax1_fused_0 in T.thread_binding(T.int64(2), 
thread="threadIdx.x"):
+                        for ax1_fused_1 in range(T.int64(2)):
+                            with T.block("NT_matmul_intermediate_pad"):
+                                v0 = T.axis.spatial(batch_size, ax0_0 * 
T.int64(4) + ax0)
+                                v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + ax1_fused_0 * T.int64(2) + ax1_fused_1)
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // 
T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < 
batch_size)
+                                T.reads(NT_matmul_intermediate_pad_local[v0, 
T.int64(0), v1])
+                                T.writes(NT_matmul_intermediate[v0, 
T.int64(0), v1])
+                                NT_matmul_intermediate[v0, T.int64(0), v1] = 
NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
+    # fmt: on
+    mod = tvm.IRModule({"main": before})
+    with Target("metal"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+def test_batch_gemv():
+    N = 4096
+    K = 4096
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(var_A: T.handle, B: T.Buffer((T.int64(N), T.int64(K)), 
"float16"), var_NT_matmul: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True), 
"tir.HoistIfThenElseExprWithBlock": 1})
+        batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(K)), 
"float16")
+        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), 
T.int64(N)), "float16")
+        # with T.block("root"):
+        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(N), 
T.int64(K)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
+                T.writes(NT_matmul[v_i0, v_i1, v_i2])
+                with T.init():
+                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0)
+                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + 
A[v_i0, v_i1, v_k] * B[v_i2, v_k]
+
+    @T.prim_func(private=True)
+    def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), 
"float16"), var_NT_matmul: T.handle):
+        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, 
"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(4096)), 
"float16")
+        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), 
T.int64(4096)), "float16")
+        # with T.block("root"):
+        NT_matmul_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // 
T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_pad_rf_local = T.alloc_buffer((T.int64(64), (batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", 
scope="local")
+        NT_matmul_pad_rf_local_1 = T.alloc_buffer((T.int64(64), (batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", 
scope="local")
+        for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), 
thread="blockIdx.y"):
+            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1024), 
thread="blockIdx.x"):
+                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(2), 
thread="threadIdx.x"):
+                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in 
T.thread_binding(T.int64(64), thread="threadIdx.y"):
+                        for ax0_1_init, u_fused_ax1_fused_fused_2_init in 
T.grid(T.int64(4), T.int64(2)):
+                            for 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in 
T.vectorized(T.int64(1)):
+                                with T.block("NT_matmul_rf_init"):
+                                    
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                    v0 = T.axis.spatial((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
+                                    v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) 
+ u_fused_ax1_fused_fused_2_init)
+                                    T.reads()
+                                    
T.writes(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 
v0, T.int64(0), v1])
+                                    
NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, 
T.int64(0), v1] = T.float16(0)
+                        for ax2_fused_u_fused_0 in T.serial(T.int64(8), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax0_1, u_fused_ax1_fused_fused_2, 
ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(8)):
+                                for 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(1)):
+                                    with T.block("NT_matmul_rf_update"):
+                                        
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + 
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                        v0 = T.axis.spatial((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
+                                        v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) 
+ u_fused_ax1_fused_fused_2)
+                                        vax2_fused_u_fused_0, 
vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, 
ax2_fused_u_fused_2])
+                                        
T.reads(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 
v0, T.int64(0), v1], A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + 
vax2_fused_u_fused_2], B[v1, vax2_fused_u_fused_0 * T.int64(512) + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + 
vax2_fused_u_fused_2])
+                                        
T.writes(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 
v0, T.int64(0), v1])
+                                        
NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, 
T.int64(0), v1] = 
NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, 
T.int64(0), v1] + T.if_then_else(v0 < batch_size, A[v0, T.int64(0), 
vax2_fused_u_fused_0 * T.int64(512) + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + 
vax2_fused_u_fused_2], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * 
T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_ [...]
+                for ax3_fused_0 in T.thread_binding(T.int64(2), 
thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(T.int64(64), 
thread="threadIdx.y"):
+                        for ax3_fused_1_0 in T.serial(T.int64(1), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax2 in range(T.int64(4)):
+                                for ax3_fused_1_1 in T.vectorized(T.int64(2)):
+                                    with T.block("NT_matmul_rf_init"):
+                                        
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), 
ax0)
+                                        v0 = T.axis.spatial((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                        v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + 
ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                        T.reads()
+                                        
T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1])
+                                        
NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, 
T.int64(0), v1] = T.float16(0)
+                                    for ax1 in range(T.int64(1)):
+                                        with T.block("NT_matmul_rf_update"):
+                                            
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, 
ax1])
+                                            v0 = T.axis.spatial((batch_size + 
T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                            v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + 
ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                            
T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1], 
NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1])
+                                            
T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1])
+                                            
NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, 
T.int64(0), v1] = 
NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, 
T.int64(0), v1] + 
NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]
+                for ax2_fused_1, ax1 in T.grid(T.int64(2), T.int64(4)):
+                    for ax2_fused_0 in T.thread_binding(T.int64(2), 
thread="threadIdx.x"):
+                        for ax0 in T.thread_binding(T.int64(64), 
thread="threadIdx.y"):
+                            with T.block("NT_matmul"):
+                                
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), 
ax0)
+                                v0 = T.axis.spatial((batch_size + T.int64(3)) 
// T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
+                                v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + ax2_fused_0 * T.int64(2) + ax2_fused_1)
+                                
T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
 v0, T.int64(0), v1])
+                                T.writes(NT_matmul_pad_local[v0, T.int64(0), 
v1])
+                                with T.init():
+                                    NT_matmul_pad_local[v0, T.int64(0), v1] = 
T.float16(0)
+                                NT_matmul_pad_local[v0, T.int64(0), v1] = 
NT_matmul_pad_local[v0, T.int64(0), v1] + 
NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, 
T.int64(0), v1]
+                for ax0 in range(T.int64(4)):
+                    for ax1_fused_0 in T.thread_binding(T.int64(2), 
thread="threadIdx.x"):
+                        for ax1_fused_1 in range(T.int64(2)):
+                            with T.block("NT_matmul_pad"):
+                                v0 = T.axis.spatial(batch_size, ax0_0 * 
T.int64(4) + ax0)
+                                v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(4) + ax1_fused_0 * T.int64(2) + ax1_fused_1)
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // 
T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < 
batch_size)
+                                T.reads(NT_matmul_pad_local[v0, T.int64(0), 
v1])
+                                T.writes(NT_matmul[v0, T.int64(0), v1])
+                                NT_matmul[v0, T.int64(0), v1] = 
NT_matmul_pad_local[v0, T.int64(0), v1]
+    # fmt: on
+    mod = tvm.IRModule({"main": before})
+    with Target("metal"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # script entry point; presumably runs this file's tests via pytest

Reply via email to