This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 741ca41814 [Unity][Dlight] general reduction rule for gemv-decode (#15169)
741ca41814 is described below

commit 741ca418143c3a53c607dfe88b4e49696aecaa85
Author: Bohan Hou <[email protected]>
AuthorDate: Sat Jul 1 22:18:27 2023 -0700

    [Unity][Dlight] general reduction rule for gemv-decode (#15169)
    
    This PR introduces a general rule for gemv-decode
    
    ---------
    
    Co-authored-by: Junru Shao <[email protected]>
    Co-authored-by: Tianqi Chen <[email protected]>
---
 python/tvm/dlight/__init__.py                      |  10 +-
 python/tvm/dlight/base/__init__.py                 |   4 +-
 python/tvm/dlight/base/analysis.py                 | 135 +++++--
 python/tvm/dlight/base/common_schedules.py         |  37 +-
 python/tvm/dlight/gpu/__init__.py                  |   1 +
 python/tvm/dlight/gpu/decode_gemv.py               | 192 +++++++++
 python/tvm/dlight/gpu/fallback.py                  |  16 +-
 python/tvm/dlight/gpu/reduction.py                 |  38 +-
 src/tir/schedule/analysis/analysis.cc              |  15 +
 src/tir/schedule/primitive/compute_inline.cc       |  21 +-
 src/tir/schedule/transform.cc                      |  52 +++
 tests/python/dlight/test_gpu_decode_gemv.py        | 432 +++++++++++++++++++++
 tests/python/dlight/test_gpu_fallback.py           |  15 +-
 tests/python/dlight/test_gpu_reduction.py          | 139 ++++---
 .../unittest/test_tir_schedule_compute_inline.py   |  24 ++
 15 files changed, 970 insertions(+), 161 deletions(-)

diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py
index 23dd17993c..421d4017d1 100644
--- a/python/tvm/dlight/__init__.py
+++ b/python/tvm/dlight/__init__.py
@@ -16,4 +16,12 @@
 # under the License.
 """DLight package provides efficient schedules out-of-box for deep learning workloads."""
 from . import gpu
-from .base import ApplyDefaultSchedule, ScheduleRule
+from .base import (
+    ApplyDefaultSchedule,
+    BlockInfo,
+    IterInfo,
+    ScheduleRule,
+    normalize_prim_func,
+    try_inline,
+    try_inline_contiguous_spatial,
+)
diff --git a/python/tvm/dlight/base/__init__.py b/python/tvm/dlight/base/__init__.py
index d14db6c4a7..b69c82fca0 100644
--- a/python/tvm/dlight/base/__init__.py
+++ b/python/tvm/dlight/base/__init__.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Base infra"""
-from .analysis import BlockInfo
-from .common_schedules import try_inline
+from .analysis import BlockInfo, IterInfo, normalize_prim_func
+from .common_schedules import try_inline, try_inline_contiguous_spatial
 from .schedule_rule import ScheduleRule
 from .transform import ApplyDefaultSchedule
diff --git a/python/tvm/dlight/base/analysis.py b/python/tvm/dlight/base/analysis.py
index a2508c87ba..a5d70c6c0c 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -15,61 +15,134 @@
 # specific language governing permissions and limitations
 # under the License.
 """Analysis on TIR blocks, loops and functions."""
-from typing import List, Union
+from typing import List, Optional, Union
+
+from typing_extensions import Literal
 
 from tvm import tir
+from tvm._ffi import get_global_func
+
+
+class IterInfo:
+    """Information about a loop/iter var."""
+
+    kind: Literal["S", "R", "O"]
+    var: tir.Var
+    _dom: tir.PrimExpr
+    loop_rv: tir.schedule.LoopRV
+
+    def __init__(
+        self,
+        kind: Literal["S", "R", "O"],
+        var: tir.Var,
+        dom: tir.PrimExpr,
+        loop_rv: tir.schedule.LoopRV,
+    ):
+        """Construct an IterInfo object."""
+        self.kind = kind
+        self.var = var
+        self._dom = dom
+        self.loop_rv = loop_rv
+
+    @property
+    def dom(self) -> Union[int, tir.PrimExpr]:
+        """The iteration domain of the loop."""
+        return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom
+
+    def __str__(self) -> str:
+        return f'Iter("{self.kind}", {self.dom})'
+
+    def __repr__(self) -> str:
+        return str(self)
 
 
 class BlockInfo:
     """Information about a TIR block."""
 
-    block: tir.schedule.BlockRV
-    """The TIR block the current schedule refers to"""
     name: str
-    """The name of the block"""
-    iters: List[tir.IterVar]
-    """The iteration domains of the current block"""
+    iters: List[IterInfo]
+    block_rv: tir.schedule.BlockRV
+    _reduction_block: bool
 
     def __init__(
         self,
-        sch: tir.Schedule,
-        block: tir.schedule.BlockRV,
+        name: str,
+        iters: List[IterInfo],
+        block_rv: tir.schedule.BlockRV,
+        reduction_block: bool = False,
     ):
-        """Construct a BlockInfo object via TIR schedule."""
-        tir_block = sch.get(block)
-        self.block = block
-        self.name = tir_block.name_hint
-        self.iters = list(tir_block.iter_vars)
+        """Construct a BlockInfo object."""
+        self.name = name
+        self.block_rv = block_rv
+        self.iters = iters
+        self._reduction_block = reduction_block
 
     def dom(self) -> List[Union[int, tir.PrimExpr]]:
         """The iteration domain of the block."""
-
-        def _iter_dom(i: tir.IterVar) -> Union[int, tir.PrimExpr]:
-            result = i.dom.extent
-            if isinstance(result, tir.IntImm):
-                result = int(result)
-            return result
-
-        result = [_iter_dom(i) for i in self.iters]
-        return result
+        return [i.dom for i in self.iters]
 
     def dom_kind(self) -> str:
         """The iteration domain kind of the block, for example, SSSS, SSSR."""
+        return "".join(i.kind for i in self.iters)
 
-        def _iter_kind(i: tir.IterVar) -> str:
-            return {
-                tir.IterVar.DataPar: "S",
-                tir.IterVar.CommReduce: "R",
-            }.get(i.iter_type, "O")
+    def is_injective(self) -> bool:
+        """Whether the block is injective, i.e. all its iteration domains are injective."""
+        return all(k == "S" for k in self.dom_kind())
 
-        return "".join(_iter_kind(i) for i in self.iters)
+    def is_reduction(self) -> bool:
+        """Whether the block is a reduction workload."""
+        # TODO(@junrushao): distinguish GEMV and reduction
+        return self._reduction_block
 
-    def is_spatial(self) -> bool:
-        """Whether the block is spatial, i.e. all its iteration domains are spatial."""
-        return all(k == "S" for k in self.dom_kind())
+    def is_gemv(self) -> bool:
+        """Whether the block is a GEMV workload."""
+        raise NotImplementedError
+
+    def is_gemm(self) -> bool:
+        """Whether the block is a GEMM workload."""
+        raise NotImplementedError
 
     def __str__(self) -> str:
         return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})'
 
     def __repr__(self) -> str:
         return str(self)
+
+
+_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")
+
+
+def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
+    """Normalize the primfunc to normal form"""
+    try:
+        result = _normalize_prim_func(sch)
+        if result is None:
+            return None
+    except Exception:  # pylint: disable=broad-except
+        return None
+
+    def _iter_kind(i: tir.IterVar) -> str:
+        return {
+            tir.IterVar.DataPar: "S",
+            tir.IterVar.CommReduce: "R",
+        }.get(i.iter_type, "O")
+
+    blocks: List[BlockInfo] = []
+    for block, loops, iters, is_reduction in zip(*result):
+        blocks.append(
+            BlockInfo(
+                name=sch.get(block).name_hint,
+                iters=[
+                    IterInfo(
+                        kind=_iter_kind(iter),  # type: ignore
+                        var=iter.var,
+                        dom=iter.dom.extent,
+                        loop_rv=loop,
+                    )
+                    for loop, iter in zip(loops, iters)
+                ],
+                block_rv=block,
+                reduction_block=is_reduction,
+            )
+        )
+    return blocks
diff --git a/python/tvm/dlight/base/common_schedules.py b/python/tvm/dlight/base/common_schedules.py
index 6568f9e5b5..b91f46c3ca 100644
--- a/python/tvm/dlight/base/common_schedules.py
+++ b/python/tvm/dlight/base/common_schedules.py
@@ -44,7 +44,7 @@ def try_inline(
     def _trial(func: Callable):
         for i, block in enumerate(blocks):
             try:
-                func(block.block)
+                func(block.block_rv)
             except:  # pylint: disable=bare-except
                 continue
             return i
@@ -58,3 +58,38 @@ def try_inline(
             break
         blocks.pop(i)
     return blocks
+
+
+def try_inline_contiguous_spatial(sch: tir.Schedule, block_infos: List[BlockInfo]):
+    """Try to spatial blocks in a schedule
+
+    Parameters
+    ----------
+    sch : tir.Schedule
+        The TIR schedule used to inline blocks.
+    block_infos : List[BlockInfo]
+        The blocks to be try.
+
+    Returns
+    -------
+    remaining : List[BlockInfo]
+        The remaining blocks that cannot be inlined.
+    """
+
+    if block_infos is None:
+        return None
+    results = []
+    spatial_blocks = []
+    block: BlockInfo
+    for block in block_infos:
+        if block.is_injective():
+            spatial_blocks.append(block)
+        elif spatial_blocks:
+            results.extend(try_inline(sch, spatial_blocks))
+            results.append(block)
+            spatial_blocks = []
+        else:
+            results.append(block)
+    if spatial_blocks:
+        results.extend(try_inline(sch, spatial_blocks))
+    return results
diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py
index 098f71d608..b689bef381 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -19,4 +19,5 @@ GPU-generic schedule rules.
 For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead
 """
 from .fallback import Fallback
+from .decode_gemv import DecodeGEMV
 from .reduction import Reduction
diff --git a/python/tvm/dlight/gpu/decode_gemv.py b/python/tvm/dlight/gpu/decode_gemv.py
new file mode 100644
index 0000000000..18395b8063
--- /dev/null
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""A fallback schedule rule for GPU operators."""
+# pylint: disable=invalid-name
+
+from typing import List, Optional, Union
+
+from tvm import tir
+from tvm._ffi import get_global_func
+from tvm.arith import normalize_to_iter_sum
+from tvm.ir import structural_equal
+from tvm.target import Target
+
+from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial
+
+
+def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
+    # Detect and return `Y` in `X[...] = X[...] + Y`
+    buffer_store = block.body
+    if not isinstance(buffer_store, tir.BufferStore):
+        return None
+    if not isinstance(buffer_store.value, tir.Add):
+        return None
+    if not structural_equal(
+        buffer_store.value.a,
+        tir.BufferLoad(buffer_store.buffer, block.body.indices),
+        map_free_vars=True,
+    ):
+        return None
+    return buffer_store.value.b
+
+
+def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
+    dominant_read, read_iters = None, None
+    tir_vars = set()
+    for buffer_region in block.reads:
+        tir_vars.clear()
+
+        def _collect_tir_var(e):
+            if isinstance(e, tir.Var):
+                tir_vars.add(e)
+
+        for expr in buffer_region.region:
+            assert expr.extent == 1
+            tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
+
+        if read_iters is None or read_iters < len(tir_vars):
+            read_iters = len(tir_vars)
+            dominant_read = buffer_region
+    assert dominant_read is not None
+    (result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region])
+    return result
+
+
+class DecodeGEMV(ScheduleRule):
+    def __init__(self) -> None:
+        super().__init__()
+        self.get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")
+
+    def apply(  # pylint: disable=too-many-locals
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+        if not isinstance(func, tir.PrimFunc):
+            return None
+
+        if target.kind.name == "cuda":
+            len_tx, len_ty = 16, 16
+        else:
+            len_tx, len_ty = 8, 8
+
+        sch = tir.Schedule(func)
+        block_infos = try_inline_contiguous_spatial(sch, normalize_prim_func(sch))
+
+        if block_infos is None or len(block_infos) > 2:
+            return None
+
+        block_info = block_infos[0]
+        block = block_info.block_rv
+        block_stmt = sch.get(block)
+
+        # Step 1. Check reduction block
+        if not block_info.is_reduction():
+            return None
+        if len(block_stmt.writes) != 1:
+            return None
+        if _get_reduction_expr(block_stmt) is None:
+            return None
+
+        # Step 2. Sort out the spatial and reduction loops
+        sorted_iter_access = normalize_to_iter_sum(
+            _detect_dominant_read(block_stmt),
+            input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+        )
+        if sorted_iter_access.base != 0:
+            return None
+        iter_to_info = {i.var: i for i in block_info.iters}
+        s_loops, r_loops, c_loops = [], [], []
+        for split in sorted_iter_access.args:
+            block_var = split.source.source
+            block_var_info = iter_to_info[block_var]
+            loop_rv = block_var_info.loop_rv
+            is_inner_reduction = block_var_info.kind == "R"
+            if split.lower_factor > 1:
+                c_loop_factor = split.lower_factor
+                loop_rv, c_loop = sch.split(loop_rv, factors=[None, c_loop_factor])
+                c_loops.append(c_loop)
+                is_loop_c_reduction = is_inner_reduction
+            if is_inner_reduction:
+                r_loops.append(loop_rv)
+            else:
+                s_loops.append(loop_rv)
+
+        if len(c_loops) > 1:
+            return None
+        if len(s_loops) != len([_ for i in block_info.iters if i.kind == "S"]):
+            return None
+        if len(s_loops) == 0 or len(r_loops) == 0:
+            return None
+
+        sch.reorder(*s_loops, *r_loops, *c_loops)
+        s = sch.fuse(*s_loops)
+        r = sch.fuse(*r_loops)
+
+        if is_inner_reduction:
+            _, tx = sch.split(r, factors=[None, len_tx * len_ty])
+            rf = sch.rfactor(tx, 0)
+            s, r, tx = sch.get_loops(rf)[:3]
+            sch.reorder(s, tx, r)
+            sch.reverse_compute_at(block, s, preserve_unit_loops=True)
+            sch.bind(tx, "threadIdx.x")
+            sch.bind(s, "blockIdx.x")
+        else:
+            sch.split(s, factors=[None, len_tx])
+            _, ty = sch.split(r, factors=[None, len_ty])
+            rf = sch.rfactor(ty, 0)
+            bx, tx, r, ty = sch.get_loops(rf)[:4]
+            sch.reorder(bx, tx, ty, r)
+            sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+            sch.bind(tx, "threadIdx.x")
+            sch.bind(ty, "threadIdx.y")
+            sch.bind(bx, "blockIdx.x")
+
+        s_loops, r_loops = [], []
+        for loop_rv in sch.get_loops(block)[1:]:
+            iter_type = self.get_loop_iter_type(sch, loop_rv)
+            if iter_type == "S":
+                s_loops.append(loop_rv)
+            elif iter_type == "R":
+                r_loops.append(loop_rv)
+            else:
+                raise RuntimeError("Unknown loop type " + str(iter_type))
+        sch.reorder(*s_loops, *r_loops)
+        s_ctr = sch.fuse(*s_loops)
+        r_ctr = sch.fuse(*r_loops)
+
+        if c_loops and not is_loop_c_reduction:
+            s_ctr, inner = sch.split(s_ctr, factors=[None, c_loop_factor])
+            sch.reorder(s_ctr, r_ctr, inner)
+
+        if is_inner_reduction:
+            sch.bind(r_ctr, "threadIdx.x")
+            sch.set_scope(rf, 0, "local")
+            sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+        else:
+            sch.bind(s_ctr, "threadIdx.x")
+            sch.bind(r_ctr, "threadIdx.y")
+            sch.set_scope(rf, 0, "local")
+            sch.decompose_reduction(rf, sch.get_loops(rf)[3])
+
+        if len(block_infos) == 2:
+            sch.set_scope(block, 0, "local")
+            sch.reverse_compute_at(block_infos[1].block_rv, sch.get_loops(block)[0])
+
+        return sch
diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py
index caefc8d563..63033aa7c7 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -21,7 +21,7 @@ from typing import List
 from tvm import tir
 from tvm.target import Target
 
-from ..base import BlockInfo, ScheduleRule, try_inline
+from ..base import ScheduleRule, normalize_prim_func, try_inline
 
 
 def _max_threads_per_block(target: Target) -> int:
@@ -49,21 +49,13 @@ class Fallback(ScheduleRule):
         max_threads_per_block = _max_threads_per_block(target)
 
         sch = tir.Schedule(func)
-        for block in try_inline(
-            sch,
-            [
-                BlockInfo(
-                    sch,
-                    block,
-                )
-                for block in sch.get_child_blocks(sch.get_block("root"))
-            ],
-        ):
+        block_infos = try_inline(sch, normalize_prim_func(sch))
+        for block in block_infos:
             s_loops: List[tir.schedule.LoopRV] = []
             r_loops: List[tir.schedule.LoopRV] = []
             o_loops: List[tir.schedule.LoopRV] = []
             dom_kind = block.dom_kind()
-            block = block.block
+            block = block.block_rv
             for loop, iter_type in zip(sch.get_loops(block), dom_kind):
                 {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop)
 
diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index b3cc58c902..bfca76546f 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -20,7 +20,7 @@ from typing import List, Union
 from tvm import tir
 from tvm.target import Target
 
-from ..base import BlockInfo, ScheduleRule, try_inline
+from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial
 
 
 class Reduction(ScheduleRule):
@@ -39,47 +39,31 @@ class Reduction(ScheduleRule):
             len_tx = 64
             unroll_depth = 64
 
-        def _inline_all_spatial():
-            blocks = []
-            spatial_blocks = []
-            for block in sch.get_child_blocks(sch.get_block("root")):
-                block = BlockInfo(sch, block)
-                if block.is_spatial():
-                    spatial_blocks.append(block)
-                elif spatial_blocks:
-                    blocks.extend(try_inline(sch, spatial_blocks))
-                    blocks.append(block)
-                    spatial_blocks = []
-                else:
-                    blocks.append(block)
-            if spatial_blocks:
-                blocks.extend(try_inline(sch, spatial_blocks))
-            return blocks
-
         sch = tir.Schedule(func)
-        blocks = _inline_all_spatial()
-        assert len(blocks) > 0
+        block_infos = normalize_prim_func(sch)
+        block_infos = try_inline_contiguous_spatial(sch, block_infos)
+        assert len(block_infos) > 0
 
-        dom_kind = blocks[0].dom_kind()
+        dom_kind = block_infos[0].dom_kind()
         num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S"))
         num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R"))
         try:
-            for block in blocks[1:-1]:
+            for block in block_infos[1:-1]:
                 assert block.dom_kind() == dom_kind
-            assert blocks[-1].is_spatial()
-            assert len(blocks[-1].dom_kind()) == len(dom_kind)
+            assert block_infos[-1].is_injective()
+            assert len(block_infos[-1].dom_kind()) == len(dom_kind)
         except AssertionError:
             print("Mismatch")
             return None
 
-        loops = sch.get_loops(blocks[-1].block)
+        loops = sch.get_loops(block_infos[-1].block_rv)
         bx = sch.fuse(*loops[:num_leading_s])  # pylint: disable=invalid-name
         _, tx = sch.split(loops[-1], [None, len_tx])  # pylint: disable=invalid-name
         sch.bind(bx, "blockIdx.x")
         sch.bind(tx, "threadIdx.x")
 
-        for block in reversed(blocks[:-1]):
-            block = block.block
+        for block in reversed(block_infos[:-1]):
+            block = block.block_rv
             for i, _ in enumerate(sch.get(block).writes):
                 sch.set_scope(block, buffer_index=i, storage_scope="shared")
             sch.compute_at(block, bx, preserve_unit_loops=True)
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 5588458235..bab0d4a3a9 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -328,6 +328,11 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
   return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0;
 }
 
+TVM_REGISTER_GLOBAL("tir.schedule.IsReductionBlock")
+    .set_body_typed([](Schedule sch, BlockRV block_rv, BlockRV scope_block_rv) {
+      return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv));
+    });
+
 void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& scope_root_sref) {
   class NotReductionBlockError : public ScheduleError {
@@ -859,6 +864,11 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
   }
 }
 
+TVM_REGISTER_GLOBAL("tir.schedule.GetBlockRealize")
+    .set_body_typed([](Schedule sch, BlockRV block_rv) {
+      return GetBlockRealize(sch->state(), sch->GetSRef(block_rv));
+    });
+
 IterVarType GetLoopIterType(const StmtSRef& loop_sref) {
   const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   const Var& loop_var = loop->loop_var;
@@ -1459,6 +1469,11 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
   return true;
 }
 
+TVM_REGISTER_GLOBAL("tir.schedule.IsTrivialBinding")
+    .set_body_typed([](Schedule sch, BlockRV block_rv) {
+      return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv));
+    });
+
 bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) {
   if (HasBeenMultiLevelTiled(block_sref)) {
     return false;
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 453f962aa5..d6be0e5805 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -177,14 +177,18 @@ class NonSingleProducerError : public ScheduleError {
         self, GetRef<Block>(consumer_block), scope_root_sref);
     class ProducerFinder : public StmtVisitor {
      public:
-      static std::vector<Block> GetProducer(const Buffer& buffer, const Block& scope_block) {
-        ProducerFinder finder(buffer);
+      static std::vector<Block> GetProducer(const ScheduleState& self,
+                                            const StmtSRef& scope_root_sref, const Buffer& buffer,
+                                            const Block& scope_block) {
+        ProducerFinder finder(self, scope_root_sref, buffer);
         finder(scope_block);
         return finder.producer_across_scope_.back();
       }
 
      private:
-      explicit ProducerFinder(const Buffer& buffer) : buffer_(buffer) {
+      explicit ProducerFinder(const ScheduleState& self, const StmtSRef& scope_root_sref,
+                              const Buffer& buffer)
+          : self_(self), scope_root_sref_(scope_root_sref), buffer_(buffer) {
         producer_across_scope_.push_back({});
       }
 
@@ -204,16 +208,23 @@ class NonSingleProducerError : public ScheduleError {
         producer_across_scope_.pop_back();
         for (const auto& write : node->writes) {
           if (write->buffer.same_as(buffer_)) {
+            // Check if the producer block is a complete block
+            StmtSRef producer_block_sref = self_->stmt2ref.at(node);
+            if (!IsCompleteBlock(self_, producer_block_sref, scope_root_sref_)) {
+              throw NonSingleProducerError(self_->mod, GetRef<Block>(node));
+            }
             producer_across_scope_.back().push_back(GetRef<Block>(node));
             break;
           }
         }
       }
+      ScheduleState self_;
+      StmtSRef scope_root_sref_;
       Buffer buffer_;
       std::vector<std::vector<Block>> producer_across_scope_;
     };
-    std::vector<Block> producer_across_scope =
-        ProducerFinder::GetProducer(consumer_buffer, GetRef<Block>(scope_block));
+    std::vector<Block> producer_across_scope = ProducerFinder::GetProducer(
+        self, scope_root_sref, consumer_buffer, GetRef<Block>(scope_block));
     if (producer_across_scope.size() != 1) {
       throw NonSingleProducerError(self->mod, GetRef<Block>(consumer_block));
     }
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index e4047273b6..35bf7b7669 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -429,5 +429,57 @@ PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) {
   return std::move(node);
 }
 
+/******** PrimFunc-level analysis and transformation ********/
+
+Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
+  BlockRV root_block = sch->GetBlock("root");
+  Array<BlockRV> blocks = sch->GetChildBlocks(root_block);
+  for (const BlockRV& block : blocks) {
+    StmtSRef block_sref = sch->GetSRef(block);
+    Array<StmtSRef> loops = GetLoops(block_sref);
+    Array<PrimExpr> binds = GetBlockRealize(sch->state(), block_sref)->iter_values;
+    if (loops.size() != binds.size()) {
+      return NullOpt;
+    }
+    for (int i = 0, n = loops.size(); i < n; ++i) {
+      const ForNode* loop = TVM_SREF_TO_FOR(loops[i]);
+      if (binds[i].get() != loop->loop_var.get()) {
+        return NullOpt;
+      }
+      if (!is_zero(loop->min)) {
+        return NullOpt;
+      }
+    }
+  }
+  Array<Array<LoopRV>> block_loops;
+  Array<Array<IterVar>> block_iters;
+  Array<IntImm> block_is_reduction;
+  for (const BlockRV& block : blocks) {
+    Array<IterVar> iters = sch->Get(block)->iter_vars;
+    Array<Var> index_map_inputs;
+    Array<PrimExpr> index_map_outputs;
+    for (const IterVar& iter : sch->Get(block)->iter_vars) {
+      Var var = iter->var.copy_with_suffix("");
+      index_map_inputs.push_back(var);
+      if (!is_one(iter->dom->extent)) {
+        index_map_outputs.push_back(var);
+      }
+    }
+    if (index_map_outputs.empty()) {
+      index_map_outputs.push_back(make_zero(DataType::Int(64)));
+    }
+    sch->TransformBlockLayout(block, IndexMap(index_map_inputs, index_map_outputs));
+    block_loops.push_back(sch->GetLoops(block));
+    block_iters.push_back(sch->Get(block)->iter_vars);
+    bool is_reduction = IsReductionBlock(sch->state(),         //
+                                         sch->GetSRef(block),  //
+                                         sch->GetSRef(root_block));
+    block_is_reduction.push_back(Bool(is_reduction));
+  }
+  return Array<ObjectRef>{blocks, block_loops, block_iters, block_is_reduction};
+}
+
+TVM_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py b/tests/python/dlight/test_gpu_decode_gemv.py
new file mode 100644
index 0000000000..46232a461e
--- /dev/null
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -0,0 +1,432 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm
+from tvm import dlight as dl
+from tvm.ir import assert_structural_equal
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def test_decode_gemv_1():
+    # NK layout + K as decode dim
+    # Before: "decode" unpacks eight 4-bit values from each uint32 word of W
+    # (shift right by (j % 8) * 4, mask with 15), subtracts 7 and multiplies by
+    # one scale from S per 32 columns, producing B[n, k] in float16; "matmul"
+    # then reduces over k: C[0, 0, n] += V[0, 0, k] * B[n, k].
+    # After: the module expected from dl.gpu.DecodeGEMV on the RTX 3090 Ti
+    # target. The decode expression is folded into the reduction (no B buffer),
+    # each of the 4096 outputs gets its own blockIdx.x, and 256 threadIdx.x
+    # lanes accumulate partial sums in C_rf_local that the final "matmul" block
+    # sums into C.
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            B = T.alloc_buffer((4096, 4096), "float16")
+            for i, j in T.grid(4096, 4096):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32])
+                    T.writes(B[v_i, v_j])
+                    B[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32]
+            for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k])
+                    T.writes(C[v_i0, v_i1, v_i2])
+                    with T.init():
+                        C[v_i0, v_i1, v_i2] = T.float16(0)
+                    C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, 
v_k] * B[v_i2, v_k]
+
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", 
scope="local")
+            for i2_i0_i1_fused in T.thread_binding(4096, thread="blockIdx.x"):
+                for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block("matmul_rf_init"):
+                        vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
+                        v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+                        C_rf_local[vk_0_fused_1, 0, 0, v_i2] = T.float16(0)
+                    for k_0_fused_0, k_1 in T.grid(2, 8):
+                        with T.block("matmul_rf_update"):
+                            vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
+                            v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR", 
[i2_i0_i1_fused, k_0_fused_0, k_1])
+                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + 
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) % 
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
+                for ax1_ax2_ax3_fused in range(1):
+                    for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
+                        with T.block("matmul"):
+                            vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
+                            v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+                            with T.init():
+                                C[0, 0, v_i2] = T.float16(0)
+                            C[0, 0, v_i2] = C[0, 0, v_i2] + 
C_rf_local[vk_0_fused_1, 0, 0, v_i2]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_2():
+    # KN layout + K as decode dim
+    # Before: W/S are packed along rows (W[k // 8, n], S[k // 32, n]), so the
+    # 4-bit unpacking runs along K, and "matmul" reads B[k, n].
+    # After: N is spread over 256 blockIdx.x x 16 threadIdx.x, while 16
+    # threadIdx.y lanes each accumulate a slice of K into C_rf_local before
+    # the final "matmul" block sums those partials into C.
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            B = T.alloc_buffer((4096, 4096), "float16")
+            for i, j in T.grid(4096, 4096):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(W[v_i // 8, v_j], S[v_i // 32, v_j])
+                    T.writes(B[v_i, v_j])
+                    B[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i // 32, v_j]
+            for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(V[v_i0, v_i1, v_k], B[v_k, v_i2])
+                    T.writes(C[v_i0, v_i1, v_i2])
+                    with T.init():
+                        C[v_i0, v_i1, v_i2] = T.float16(0)
+                    C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, 
v_k] * B[v_k, v_i2]
+
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", 
scope="local")
+            for i2_i0_i1_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
+                for i2_i0_i1_fused_1 in T.thread_binding(16, 
thread="threadIdx.x"):
+                    for k_0_fused_1 in T.thread_binding(16, 
thread="threadIdx.y"):
+                        with T.block("matmul_rf_init"):
+                            vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1)
+                            v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 
+ i2_i0_i1_fused_1)
+                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = T.float16(0)
+                        for k_0_fused_0, k_1 in T.grid(32, 8):
+                            with T.block("matmul_rf_update"):
+                                vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1)
+                                v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 
16 + i2_i0_i1_fused_1)
+                                vk_0_fused_0, vk_1 = T.axis.remap("RR", 
[k_0_fused_0, k_1])
+                                C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 128 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 
8, v_i2], T.Cast("uint32", (vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) % 8) 
* T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 128 + 
vk_0_fused_1 * 8 + vk_1) // 32, v_i2])
+                for ax1_ax2_ax3_fused in T.thread_binding(16, 
thread="threadIdx.x"):
+                    for ax0_fused in T.thread_binding(16, 
thread="threadIdx.y"):
+                        with T.block("matmul"):
+                            vk_0_fused_1 = T.axis.reduce(16, ax0_fused)
+                            v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 
+ ax1_ax2_ax3_fused)
+                            with T.init():
+                                C[0, 0, v_i2] = T.float16(0)
+                            C[0, 0, v_i2] = C[0, 0, v_i2] + 
C_rf_local[vk_0_fused_1, 0, 0, v_i2]
+
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_3():
+    # NK layout + N as decode dim
+    # Before: W/S are packed along rows (W[n // 8, k], S[n // 32, k]), so the
+    # 4-bit unpacking runs along N, and "matmul" reads B[n, k].
+    # After: each blockIdx.x owns 8 consecutive outputs n, which all read the
+    # same uint32 word W[n // 8, k]. 256 threadIdx.x lanes each accumulate a
+    # slice of K for those 8 outputs in C_rf_local, then the final "matmul"
+    # block sums the partials into C.
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            B = T.alloc_buffer((4096, 4096), "float16")
+            for i, j in T.grid(4096, 4096):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(W[v_i // 8, v_j], S[v_i // 32, v_j])
+                    T.writes(B[v_i, v_j])
+                    B[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i // 32, v_j]
+            for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k])
+                    T.writes(C[v_i0, v_i1, v_i2])
+                    with T.init():
+                        C[v_i0, v_i1, v_i2] = T.float16(0)
+                    C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, 
v_k] * B[v_i2, v_k]
+
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", 
scope="local")
+            for i2_0_i0_i1_fused in T.thread_binding(512, thread="blockIdx.x"):
+                for k_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                    for i2_1_init in range(8):
+                        with T.block("matmul_rf_init"):
+                            vk_fused_1 = T.axis.spatial(256, k_fused_1)
+                            v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + 
i2_1_init)
+                            C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0)
+                    for k_fused_0, i2_1 in T.grid(16, 8):
+                        with T.block("matmul_rf_update"):
+                            vk_fused_1 = T.axis.spatial(256, k_fused_1)
+                            v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + 
i2_1)
+                            vk_fused_0 = T.axis.reduce(16, k_fused_0)
+                            C_rf_local[vk_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 256 + vk_fused_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2 // 8, vk_fused_0 * 256 + 
vk_fused_1], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[v_i2 // 32, vk_fused_0 * 256 + vk_fused_1])
+                for ax1_ax2_ax3_fused_0 in range(1):
+                    for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
+                        for ax1_ax2_ax3_fused_1 in range(8):
+                            with T.block("matmul"):
+                                vk_fused_1 = T.axis.reduce(256, ax0_fused)
+                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 
8 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
+                                with T.init():
+                                    C[0, 0, v_i2] = T.float16(0)
+                                C[0, 0, v_i2] = C[0, 0, v_i2] + 
C_rf_local[vk_fused_1, 0, 0, v_i2]
+
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_4():
+    # KN layout + N as decode dim
+    # Before: W/S are packed along columns (W[k, n // 8], S[k, n // 32]), so
+    # the 4-bit unpacking runs along N, and "matmul" reads B[k, n].
+    # After: 32 blockIdx.x x 16 threadIdx.x threads each own 8 consecutive
+    # outputs n, which share one uint32 word per row of W. 16 threadIdx.y
+    # lanes each accumulate a slice of K into C_rf_local before the final
+    # "matmul" block sums the partials into C.
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            B = T.alloc_buffer((4096, 4096), "float16")
+            for i, j in T.grid(4096, 4096):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32])
+                    T.writes(B[v_i, v_j])
+                    B[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32]
+            for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(V[v_i0, v_i1, v_k], B[v_k, v_i2])
+                    T.writes(C[v_i0, v_i1, v_i2])
+                    with T.init():
+                        C[v_i0, v_i1, v_i2] = T.float16(0)
+                    C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, 
v_k] * B[v_k, v_i2]
+
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", 
scope="local")
+            for i2_0_i0_i1_fused_0 in T.thread_binding(32, 
thread="blockIdx.x"):
+                for i2_0_i0_i1_fused_1 in T.thread_binding(16, 
thread="threadIdx.x"):
+                    for k_fused_1 in T.thread_binding(16, 
thread="threadIdx.y"):
+                        for i2_1_init in range(8):
+                            with T.block("matmul_rf_init"):
+                                vk_fused_1 = T.axis.spatial(16, k_fused_1)
+                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init)
+                                C_rf_local[vk_fused_1, 0, 0, v_i2] = 
T.float16(0)
+                        for k_fused_0, i2_1 in T.grid(256, 8):
+                            with T.block("matmul_rf_update"):
+                                vk_fused_1 = T.axis.spatial(16, k_fused_1)
+                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1)
+                                vk_fused_0 = T.axis.reduce(256, k_fused_0)
+                                C_rf_local[vk_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1, 
v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32])
+                for ax1_ax2_ax3_fused_0 in T.thread_binding(16, 
thread="threadIdx.x"):
+                    for ax0_fused in T.thread_binding(16, 
thread="threadIdx.y"):
+                        for ax1_ax2_ax3_fused_1 in range(8):
+                            with T.block("matmul"):
+                                vk_fused_1 = T.axis.reduce(16, ax0_fused)
+                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
+                                with T.init():
+                                    C[0, 0, v_i2] = T.float16(0)
+                                C[0, 0, v_i2] = C[0, 0, v_i2] + 
C_rf_local[vk_fused_1, 0, 0, v_i2]
+
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_sigmoid():
+    # NK layout + K as decode dim
+    # Same decode + GEMV as test_decode_gemv_1, but the GEMV result C is an
+    # intermediate buffer consumed by an elementwise "sigmoid" block that
+    # writes the output D.
+    # After: C is replaced by C_local (scope="local"), and the "sigmoid" block
+    # sits inside the same blockIdx.x loop, right after the final reduction.
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), D: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            B = T.alloc_buffer((4096, 4096), "float16")
+            C = T.alloc_buffer((1, 1, 4096), "float16")
+            for i, j in T.grid(4096, 4096):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32])
+                    T.writes(B[v_i, v_j])
+                    B[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32]
+            for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k])
+                    T.writes(C[v_i0, v_i1, v_i2])
+                    with T.init():
+                        C[v_i0, v_i1, v_i2] = T.float16(0)
+                    C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, 
v_k] * B[v_i2, v_k]
+            for i0, i1, i2 in T.grid(1, 1, 4096):
+                with T.block("sigmoid"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(C[v_i0, v_i1, v_i2])
+                    T.writes(D[v_i0, v_i1, v_i2])
+                    D[v_i0, v_i1, v_i2] = T.sigmoid(C[v_i0, v_i1, v_i2])
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), D: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            C_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
+            C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", 
scope="local")
+            for i2_i0_i1_fused in T.thread_binding(4096, thread="blockIdx.x"):
+                for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block("matmul_rf_init"):
+                        vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
+                        v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+                        C_rf_local[vk_0_fused_1, 0, 0, v_i2] = T.float16(0)
+                    for k_0_fused_0, k_1 in T.grid(2, 8):
+                        with T.block("matmul_rf_update"):
+                            vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
+                            v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR", 
[i2_i0_i1_fused, k_0_fused_0, k_1])
+                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + 
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) % 
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
+                for ax1_ax2_ax3_fused in range(1):
+                    for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
+                        with T.block("matmul"):
+                            vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
+                            v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+                            with T.init():
+                                C_local[0, 0, v_i2] = T.float16(0)
+                            C_local[0, 0, v_i2] = C_local[0, 0, v_i2] + 
C_rf_local[vk_0_fused_1, 0, 0, v_i2]
+                with T.block("sigmoid"):
+                    v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+                    D[0, 0, v_i2] = T.sigmoid(C_local[0, 0, v_i2])
+
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_1_fp32():
+    # NK layout + K as decode dim
+    # Same decode as test_decode_gemv_1, but "matmul" accumulates in float32
+    # (both operands are cast up) into C_fp32, and a separate "cast" block
+    # converts the result back to the float16 output C.
+    # After: the partial sums (C_fp32_rf_local) and the reduced value
+    # (C_fp32_local) live in local buffers, and "cast" runs inside the same
+    # blockIdx.x loop that owns each output element.
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            B = T.alloc_buffer((4096, 4096), "float16")
+            C_fp32 = T.alloc_buffer((1, 1, 4096), "float32")
+            for i, j in T.grid(4096, 4096):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32])
+                    T.writes(B[v_i, v_j])
+                    B[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32]
+            for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k])
+                    T.writes(C_fp32[v_i0, v_i1, v_i2])
+                    with T.init():
+                        # C_fp32 is a float32 buffer, so initialise it with a
+                        # float32 literal (matches the expected After module).
+                        C_fp32[v_i0, v_i1, v_i2] = T.float32(0)
+                    C_fp32[v_i0, v_i1, v_i2] = C_fp32[v_i0, v_i1, v_i2] + 
T.Cast("float32", V[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_i2, v_k])
+            for i0, i1, i2 in T.grid(1, 1, 4096):
+                with T.block("cast"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(C_fp32[v_i0, v_i1, v_i2])
+                    T.writes(C[v_i0, v_i1, v_i2])
+                    C[v_i0, v_i1, v_i2] = T.Cast("float16", C_fp32[v_i0, v_i1, 
v_i2])
+
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
+            T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            C_fp32_local = T.alloc_buffer((1, 1, 4096), scope="local")
+            C_fp32_rf_local = T.alloc_buffer((256, 1, 1, 4096), scope="local")
+            for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"):
+                for ax1_0_fused_1 in T.thread_binding(256, 
thread="threadIdx.x"):
+                    with T.block("matmul_rf_init"):
+                        vax1_0_fused_1, v0 = T.axis.remap("SS", 
[ax1_0_fused_1, ax0_fused])
+                        T.reads()
+                        T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
+                        C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = 
T.float32(0)
+                    for ax1_0_fused_0, ax1_1 in T.grid(2, 8):
+                        with T.block("matmul_rf_update"):
+                            vax1_0_fused_1, v0, vax1_0_fused_0, vax1_1 = 
T.axis.remap("SSRR", [ax1_0_fused_1, ax0_fused, ax1_0_fused_0, ax1_1])
+                            T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0], 
V[0, 0, vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1], W[v0, 
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0, 
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 32])
+                            T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
+                            C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = 
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, 
vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1]) * T.Cast("float32", 
(T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 2048 + 
vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 2048 + 
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[v0, (vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 32])
+                for ax1_fused in range(1):
+                    for ax0_fused_1 in T.thread_binding(256, 
thread="threadIdx.x"):
+                        with T.block("matmul"):
+                            vax1_0_fused_1, v0 = T.axis.remap("RS", 
[ax0_fused_1, ax0_fused])
+                            T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
+                            T.writes(C_fp32_local[0, 0, v0])
+                            with T.init():
+                                C_fp32_local[0, 0, v0] = T.float32(0)
+                            C_fp32_local[0, 0, v0] = C_fp32_local[0, 0, v0] + 
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]
+                with T.block("cast"):
+                    v0 = T.axis.spatial(4096, ax0_fused)
+                    T.reads(C_fp32_local[0, 0, v0])
+                    T.writes(C[0, 0, v0])
+                    C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0, v0])
+
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, After)
+
+
+if __name__ == "__main__":
+    # Run each test in definition order when this file is executed directly.
+    for _case in (
+        test_decode_gemv_1,
+        test_decode_gemv_2,
+        test_decode_gemv_3,
+        test_decode_gemv_4,
+        test_decode_gemv_sigmoid,
+        test_decode_gemv_1_fp32,
+    ):
+        _case()
diff --git a/tests/python/dlight/test_gpu_fallback.py 
b/tests/python/dlight/test_gpu_fallback.py
index 8d20169c53..38e9a391dc 100644
--- a/tests/python/dlight/test_gpu_fallback.py
+++ b/tests/python/dlight/test_gpu_fallback.py
@@ -48,16 +48,13 @@ def test_fallback():
             C: T.Buffer((1, 1, 4096), "float16"),
         ):
             T.func_attr({"tir.is_scheduled": 1})
-            # with T.block("root"):
-            for i_j_k_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
-                for i_j_k_fused_1 in T.thread_binding(1024, 
thread="threadIdx.x"):
+            for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(1024, 
thread="threadIdx.x"):
                     with T.block("T_reshape"):
-                        vi = T.axis.spatial(1, 0)
-                        vj = T.axis.spatial(1, 0)
-                        vk = T.axis.spatial(4096, i_j_k_fused_0 * 1024 + 
i_j_k_fused_1)
-                        T.reads(A[0, vk % 4096 // 128, 0, vk % 128])
-                        T.writes(C[vi, vj, vk])
-                        C[vi, vj, vk] = A[0, vk % 4096 // 128, 0, vk % 128]
+                        v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + 
ax0_fused_1)
+                        T.reads(A[0, v0 // 128, 0, v0 % 128])
+                        T.writes(C[0, 0, v0])
+                        C[0, 0, v0] = A[0, v0 // 128, 0, v0 % 128]
 
     target = Target("nvidia/geforce-rtx-3090-ti")
     with target:
diff --git a/tests/python/dlight/test_gpu_reduction.py 
b/tests/python/dlight/test_gpu_reduction.py
index 99307093c8..ce3a05e771 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -93,44 +93,41 @@ def test_softmax():
             # with T.block("root"):
             T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1), 
T.int64(32), n), scope="shared")
             T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(32), 
n), scope="shared")
-            for i0_i1_i2_fused in T.thread_binding(n * T.int64(32), 
thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
-                for ax0, ax1, ax2, ax3_fused_0 in T.grid(T.int64(1), 
T.int64(1), T.int64(1), (m + T.int64(255)) // T.int64(256)):
-                    for ax3_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+            for ax0_ax1_fused in T.thread_binding(n * T.int64(32), 
thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
+                for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1), (m 
+ T.int64(255)) // T.int64(256)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
                         with T.block("T_softmax_maxelem"):
-                            v_i0 = T.axis.spatial(T.int64(1), ax0)
-                            v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused 
// n + ax1)
-                            v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n + ax2)
-                            v_k = T.axis.reduce(m, ax3_fused_0 * T.int64(256) 
+ ax3_fused_1)
-                            T.where(T.int64(0) <= i0_i1_i2_fused // n and 
i0_i1_i2_fused // n < T.int64(32) and T.int64(0) <= i0_i1_i2_fused % n and 
i0_i1_i2_fused % n < n and ax3_fused_0 * T.int64(256) + ax3_fused_1 < m)
-                            T.reads(lv44[v_i0, v_i1, v_i2, v_k])
-                            T.writes(T_softmax_maxelem_shared[v_i0, v_i1, 
v_i2])
+                            v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // 
n + ax0)
+                            v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
+                            v2 = T.axis.reduce(m, ax2_fused_0 * T.int64(256) + 
ax2_fused_1)
+                            T.where(T.int64(0) <= ax0_ax1_fused // n and 
ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and 
ax0_ax1_fused % n < n and ax2_fused_0 * T.int64(256) + ax2_fused_1 < m)
+                            T.reads(lv44[T.int64(0), v0, v1, v2])
+                            T.writes(T_softmax_maxelem_shared[T.int64(0), v0, 
v1])
                             with T.init():
-                                T_softmax_maxelem_shared[v_i0, v_i1, v_i2] = 
T.float32(-3.4028234663852886e+38)
-                            T_softmax_maxelem_shared[v_i0, v_i1, v_i2] = 
T.max(T_softmax_maxelem_shared[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k])
-                for ax0, ax1, ax2, ax3_fused_0 in T.grid(T.int64(1), 
T.int64(1), T.int64(1), (m + T.int64(255)) // T.int64(256)):
-                    for ax3_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                                T_softmax_maxelem_shared[T.int64(0), v0, v1] = 
T.float32(-3.4028234663852886e+38)
+                            T_softmax_maxelem_shared[T.int64(0), v0, v1] = 
T.max(T_softmax_maxelem_shared[T.int64(0), v0, v1], lv44[T.int64(0), v0, v1, 
v2])
+                for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1), (m 
+ T.int64(255)) // T.int64(256)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
                         with T.block("T_softmax_expsum"):
-                            v_i0 = T.axis.spatial(T.int64(1), ax0)
-                            v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused 
// n + ax1)
-                            v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n + ax2)
-                            v_k = T.axis.reduce(m, ax3_fused_0 * T.int64(256) 
+ ax3_fused_1)
-                            T.where(T.int64(0) <= i0_i1_i2_fused // n and 
i0_i1_i2_fused // n < T.int64(32) and T.int64(0) <= i0_i1_i2_fused % n and 
i0_i1_i2_fused % n < n and ax3_fused_0 * T.int64(256) + ax3_fused_1 < m)
-                            T.reads(lv44[v_i0, v_i1, v_i2, v_k], 
T_softmax_maxelem_shared[v_i0, v_i1, v_i2])
-                            T.writes(T_softmax_expsum_shared[v_i0, v_i1, v_i2])
+                            v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // 
n + ax0)
+                            v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
+                            v2 = T.axis.reduce(m, ax2_fused_0 * T.int64(256) + 
ax2_fused_1)
+                            T.where(T.int64(0) <= ax0_ax1_fused // n and 
ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and 
ax0_ax1_fused % n < n and ax2_fused_0 * T.int64(256) + ax2_fused_1 < m)
+                            T.reads(lv44[T.int64(0), v0, v1, v2], 
T_softmax_maxelem_shared[T.int64(0), v0, v1])
+                            T.writes(T_softmax_expsum_shared[T.int64(0), v0, 
v1])
                             with T.init():
-                                T_softmax_expsum_shared[v_i0, v_i1, v_i2] = 
T.float32(0)
-                            T_softmax_expsum_shared[v_i0, v_i1, v_i2] = 
T_softmax_expsum_shared[v_i0, v_i1, v_i2] + T.exp(lv44[v_i0, v_i1, v_i2, v_k] - 
T_softmax_maxelem_shared[v_i0, v_i1, v_i2])
-                for i3_0 in range((m + T.int64(255)) // T.int64(256)):
-                    for i3_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                                T_softmax_expsum_shared[T.int64(0), v0, v1] = 
T.float32(0)
+                            T_softmax_expsum_shared[T.int64(0), v0, v1] = 
T_softmax_expsum_shared[T.int64(0), v0, v1] + T.exp(lv44[T.int64(0), v0, v1, 
v2] - T_softmax_maxelem_shared[T.int64(0), v0, v1])
+                for ax2_0 in range((m + T.int64(255)) // T.int64(256)):
+                    for ax2_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
                         with T.block("compute"):
-                            v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
-                            v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused 
// n)
-                            v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n)
-                            v_i3 = T.axis.spatial(m, i3_0 * T.int64(256) + 
i3_1)
-                            T.where(i3_0 * T.int64(256) + i3_1 < m)
-                            T.reads(lv44[v_i0, v_i1, v_i2, v_i3], 
T_softmax_maxelem_shared[v_i0, v_i1, v_i2], T_softmax_expsum_shared[v_i0, v_i1, 
v_i2])
-                            T.writes(var_compute_intermediate[v_i0, v_i1, 
v_i2, v_i3])
-                            var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = 
T.Cast("float16", T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - 
T_softmax_maxelem_shared[v_i0, v_i1, v_i2]) / T_softmax_expsum_shared[v_i0, 
v_i1, v_i2])
+                            v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // 
n)
+                            v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                            v2 = T.axis.spatial(m, ax2_0 * T.int64(256) + 
ax2_1)
+                            T.where(ax2_0 * T.int64(256) + ax2_1 < m)
+                            T.reads(lv44[T.int64(0), v0, v1, v2], 
T_softmax_maxelem_shared[T.int64(0), v0, v1], 
T_softmax_expsum_shared[T.int64(0), v0, v1])
+                            T.writes(var_compute_intermediate[T.int64(0), v0, 
v1, v2])
+                            var_compute_intermediate[T.int64(0), v0, v1, v2] = 
T.Cast("float16", T.exp(lv44[T.int64(0), v0, v1, v2] - 
T_softmax_maxelem_shared[T.int64(0), v0, v1]) / 
T_softmax_expsum_shared[T.int64(0), v0, v1])
     # fmt: on
     _check(Before, After)
 
@@ -185,31 +182,29 @@ def test_layer_norm():
             # with T.block("root"):
             A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n), 
scope="shared")
             A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n), 
scope="shared")
-            for i0_i1_fused in T.thread_binding(n, thread="blockIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1), 
T.int64(10)):
-                    for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+            for ax0_fused in T.thread_binding(n, thread="blockIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                for ax0, ax1_fused_0 in T.grid(T.int64(1), T.int64(10)):
+                    for ax1_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
                         with T.block("A_red_temp"):
-                            v_ax0 = T.axis.spatial(T.int64(1), ax0)
-                            v_ax1 = T.axis.spatial(n, i0_i1_fused + ax1)
-                            v_k2 = T.axis.reduce(T.int64(2560), ax2_fused_0 * 
T.int64(256) + ax2_fused_1)
-                            T.reads(lv6[v_ax0, v_ax1, v_k2])
-                            T.writes(A_red_temp_v0_shared[v_ax0, v_ax1], 
A_red_temp_v1_shared[v_ax0, v_ax1])
+                            v0 = T.axis.spatial(n, ax0_fused + ax0)
+                            v1 = T.axis.reduce(T.int64(2560), ax1_fused_0 * 
T.int64(256) + ax1_fused_1)
+                            T.reads(lv6[T.int64(0), v0, v1])
+                            T.writes(A_red_temp_v0_shared[T.int64(0), v0], 
A_red_temp_v1_shared[T.int64(0), v0])
                             with T.init():
-                                A_red_temp_v0_shared[v_ax0, v_ax1] = 
T.float32(0)
-                                A_red_temp_v1_shared[v_ax0, v_ax1] = 
T.float32(0)
-                            v_A_red_temp_v0: T.float32 = 
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
-                            v_A_red_temp_v1: T.float32 = 
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, 
v_ax1, v_k2]
-                            A_red_temp_v0_shared[v_ax0, v_ax1] = 
v_A_red_temp_v0
-                            A_red_temp_v1_shared[v_ax0, v_ax1] = 
v_A_red_temp_v1
-                for i2_0 in range(T.int64(10)):
-                    for i2_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                                A_red_temp_v0_shared[T.int64(0), v0] = 
T.float32(0)
+                                A_red_temp_v1_shared[T.int64(0), v0] = 
T.float32(0)
+                            v_A_red_temp_v0: T.float32 = 
A_red_temp_v0_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1]
+                            v_A_red_temp_v1: T.float32 = 
A_red_temp_v1_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] * 
lv6[T.int64(0), v0, v1]
+                            A_red_temp_v0_shared[T.int64(0), v0] = 
v_A_red_temp_v0
+                            A_red_temp_v1_shared[T.int64(0), v0] = 
v_A_red_temp_v1
+                for ax1_0 in range(T.int64(10)):
+                    for ax1_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
                         with T.block("compute"):
-                            v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
-                            v_i1 = T.axis.spatial(n, i0_i1_fused)
-                            v_i2 = T.axis.spatial(T.int64(2560), i2_0 * 
T.int64(256) + i2_1)
-                            T.reads(lv6[v_i0, v_i1, v_i2], 
A_red_temp_v0_shared[v_i0, v_i1], A_red_temp_v1_shared[v_i0, v_i1], 
weight1[v_i2], bias[v_i2])
-                            T.writes(var_compute_intermediate[v_i0, v_i1, 
v_i2])
-                            var_compute_intermediate[v_i0, v_i1, v_i2] = 
T.Cast("float16", (lv6[v_i0, v_i1, v_i2] - A_red_temp_v0_shared[v_i0, v_i1] * 
T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[v_i0, v_i1] * 
T.float32(0.00039062500000000002) - A_red_temp_v0_shared[v_i0, v_i1] * 
T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[v_i0, v_i1] * 
T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * 
weight1[v_i2] + bias[v_i2])
+                            v0 = T.axis.spatial(n, ax0_fused)
+                            v1 = T.axis.spatial(T.int64(2560), ax1_0 * 
T.int64(256) + ax1_1)
+                            T.reads(lv6[T.int64(0), v0, v1], 
A_red_temp_v0_shared[T.int64(0), v0], A_red_temp_v1_shared[T.int64(0), v0], 
weight1[v1], bias[v1])
+                            T.writes(var_compute_intermediate[T.int64(0), v0, 
v1])
+                            var_compute_intermediate[T.int64(0), v0, v1] = 
T.Cast("float16", (lv6[T.int64(0), v0, v1] - A_red_temp_v0_shared[T.int64(0), 
v0] * T.float32(0.00039062500000000002)) * 
T.rsqrt(A_red_temp_v1_shared[T.int64(0), v0] * 
T.float32(0.00039062500000000002) - A_red_temp_v0_shared[T.int64(0), v0] * 
T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[T.int64(0), v0] * 
T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * 
weight1[v1] + bias[v1])
     # fmt: on
     _check(Before, After)
 
@@ -251,27 +246,25 @@ def test_rms_norm():
             rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, 
T.int64(4096)), "float16")
             # with T.block("root"):
             Ared_temp_shared = T.alloc_buffer((T.int64(1), n), scope="shared")
-            for bsz_i_fused in T.thread_binding(n, thread="blockIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1), 
T.int64(16)):
-                    for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+            for ax0_fused in T.thread_binding(n, thread="blockIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                for ax0, ax1_fused_0 in T.grid(T.int64(1), T.int64(16)):
+                    for ax1_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
                         with T.block("Ared_temp"):
-                            v_bsz = T.axis.spatial(T.int64(1), ax0)
-                            v_i = T.axis.spatial(n, bsz_i_fused + ax1)
-                            v_k = T.axis.reduce(T.int64(4096), ax2_fused_0 * 
T.int64(256) + ax2_fused_1)
-                            T.reads(A[v_bsz, v_i, v_k])
-                            T.writes(Ared_temp_shared[v_bsz, v_i])
+                            v0 = T.axis.spatial(n, ax0_fused + ax0)
+                            v1 = T.axis.reduce(T.int64(4096), ax1_fused_0 * 
T.int64(256) + ax1_fused_1)
+                            T.reads(A[T.int64(0), v0, v1])
+                            T.writes(Ared_temp_shared[T.int64(0), v0])
                             with T.init():
-                                Ared_temp_shared[v_bsz, v_i] = T.float32(0)
-                            Ared_temp_shared[v_bsz, v_i] = 
Ared_temp_shared[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * 
T.Cast("float32", A[v_bsz, v_i, v_k])
-                for k_0 in range(T.int64(16)):
-                    for k_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                                Ared_temp_shared[T.int64(0), v0] = T.float32(0)
+                            Ared_temp_shared[T.int64(0), v0] = 
Ared_temp_shared[T.int64(0), v0] + T.Cast("float32", A[T.int64(0), v0, v1]) * 
T.Cast("float32", A[T.int64(0), v0, v1])
+                for ax1_0 in range(T.int64(16)):
+                    for ax1_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
                         with T.block("rms_norm"):
-                            v_bsz = T.axis.spatial(T.int64(1), T.int64(0))
-                            v_i = T.axis.spatial(n, bsz_i_fused)
-                            v_k = T.axis.spatial(T.int64(4096), k_0 * 
T.int64(256) + k_1)
-                            T.reads(B[v_k], A[v_bsz, v_i, v_k], 
Ared_temp_shared[v_bsz, v_i])
-                            T.writes(rms_norm_1[v_bsz, v_i, v_k])
-                            rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", 
T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / 
T.sqrt(Ared_temp_shared[v_bsz, v_i] * T.float32(0.000244140625) + 
T.float32(9.9999999999999995e-07))))
+                            v0 = T.axis.spatial(n, ax0_fused)
+                            v1 = T.axis.spatial(T.int64(4096), ax1_0 * 
T.int64(256) + ax1_1)
+                            T.reads(B[v1], A[T.int64(0), v0, v1], 
Ared_temp_shared[T.int64(0), v0])
+                            T.writes(rms_norm_1[T.int64(0), v0, v1])
+                            rms_norm_1[T.int64(0), v0, v1] = T.Cast("float16", 
T.Cast("float32", B[v1]) * (T.Cast("float32", A[T.int64(0), v0, v1]) / 
T.sqrt(Ared_temp_shared[T.int64(0), v0] * T.float32(0.000244140625) + 
T.float32(9.9999999999999995e-07))))
     # fmt: on
     _check(Before, After)
 
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py 
b/tests/python/unittest/test_tir_schedule_compute_inline.py
index 8d90189507..b9779c34cf 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -694,6 +694,23 @@ def elementwise_producer_not_cover_consumer(
             D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj], 
T.float32(0), dtype="float32")
 
 
+@T.prim_func
+def elementwise_producer_is_reduction(
+    A: T.Buffer((128, 128), "float32"), D: T.Buffer((128), "float32")
+) -> None:
+    B = T.alloc_buffer((128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SR", [i, j])
+            with T.init():
+                B[vi] = T.float32(0)
+            B[vi] = B[vi] + A[vi, vj]
+    for i in T.grid(128):
+        with T.block("C"):
+            vi = T.axis.remap("S", [i])
+            D[vi] = B[vi] + 1.0
+
+
 @T.prim_func
 def elementwise_predicate_producer(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (128, 128))
@@ -1267,6 +1284,13 @@ def 
test_reverse_compute_inline_producer_predicate_disallowed():
     )
 
 
+def test_reverse_compute_inline_producer_is_reduction():
+    """Test reverse compute inline when producer is reduction"""
+    sch = tir.Schedule(elementwise_producer_is_reduction, debug_mask="all")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reverse_compute_inline(sch.get_block("C"))
+
+
 def test_compute_inline_softmax():
     # fmt: off
     @T.prim_func

Reply via email to