This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 04f22a9257 [Dlight] Enhance Decode-GEMV Schedule (#15195)
04f22a9257 is described below

commit 04f22a92575b659b17d318470a00603d7faf9714
Author: Junru Shao <[email protected]>
AuthorDate: Tue Jul 4 21:47:32 2023 -0700

    [Dlight] Enhance Decode-GEMV Schedule (#15195)
    
    [Dlight] Enhance Decode-GEMV Rules
    
    This PR enhances Decode-GEMV rule with the following changes:
    - Normalize the GEMV iter domain to S-R-C via transform-block-layout.
      This would help with further analysis and scheduling — for example,
      in cases where there was no spatial loop in the original reduction
      block.
    - Get rid of the ad hoc iter type analysis, including the logic calling
      into a TVM packed func `tir.schedule.GetLoopIterType` using
      `tvm._ffi.get_global_func`.
    - Split out the logic for two separate cases of scheduling, where the
      innermost dimension is spatial or reduction.
    - Introduces `suggest_threads_per_block` to guess the threads to be
      allocated to each threadblock. This helps avoid the previous case where
      dlight allocates 256 threads for a workload whose degree of parallelism
      is only 128.
    - Misc improvements.
    
    The rest of the changes are split out to separate PRs that are already
    merged to main.
---
 pyproject.toml                              |   4 +
 python/tvm/dlight/base/analysis.py          |   3 +-
 python/tvm/dlight/gpu/__init__.py           |   4 +-
 python/tvm/dlight/gpu/decode_gemv.py        | 257 ++++++++++++++++------------
 python/tvm/dlight/gpu/fallback.py           |   5 +-
 python/tvm/dlight/gpu/matmul.py             |   4 +-
 python/tvm/dlight/gpu/utils.py              |  87 ++++++++++
 tests/python/dlight/test_gpu_decode_gemv.py |  53 +++---
 8 files changed, 275 insertions(+), 142 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5cca711ddb..e984b41b11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+[tool.isort]
+profile = "black"
+src_paths = ["python", "tests/python"]
+
 
 [tool.black]
 line-length = 100
diff --git a/python/tvm/dlight/base/analysis.py 
b/python/tvm/dlight/base/analysis.py
index d11e29a8ad..2607968ef2 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -17,13 +17,12 @@
 """Analysis on TIR blocks, loops and functions."""
 from typing import List, Optional, Union
 
-from typing_extensions import Literal
-
 from tvm import tir
 from tvm._ffi import get_global_func
 from tvm.target.target import Target
 from tvm.tir import Schedule
 from tvm.tir.schedule import BlockRV
+from typing_extensions import Literal
 
 
 class IterInfo:
diff --git a/python/tvm/dlight/gpu/__init__.py 
b/python/tvm/dlight/gpu/__init__.py
index 79090d400b..934928ffaf 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -18,7 +18,7 @@
 GPU-generic schedule rules.
 For CUDA/ROCm/Vulkan/Metal-specific rules, use 
`tvm.dlight.cuda/rocm/vulkan/metal` instead
 """
-from .fallback import Fallback
 from .decode_gemv import DecodeGEMV
-from .reduction import Reduction
+from .fallback import Fallback
 from .matmul import Matmul
+from .reduction import Reduction
diff --git a/python/tvm/dlight/gpu/decode_gemv.py 
b/python/tvm/dlight/gpu/decode_gemv.py
index 18395b8063..b9e8b44ef2 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -14,19 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=missing-docstring
-"""A fallback schedule rule for GPU operators."""
-# pylint: disable=invalid-name
+"""A rule for DecodeGEMV."""
+from typing import List, Optional, Set, Tuple, Union
 
-from typing import List, Optional, Union
-
-from tvm import tir
-from tvm._ffi import get_global_func
-from tvm.arith import normalize_to_iter_sum
+from tvm import arith, tir
 from tvm.ir import structural_equal
 from tvm.target import Target
 
-from ..base import ScheduleRule, normalize_prim_func, 
try_inline_contiguous_spatial
+from ..base import (
+    BlockInfo,
+    ScheduleRule,
+    normalize_prim_func,
+    try_inline_contiguous_spatial,
+)
+from . import utils
 
 
 def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
@@ -47,13 +48,13 @@ def _get_reduction_expr(block: tir.Block) -> 
Optional[tir.PrimExpr]:
 
 def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
     dominant_read, read_iters = None, None
-    tir_vars = set()
+    tir_vars: Set[tir.Var] = set()
     for buffer_region in block.reads:
         tir_vars.clear()
 
-        def _collect_tir_var(e):
-            if isinstance(e, tir.Var):
-                tir_vars.add(e)
+        def _collect_tir_var(expr):
+            if isinstance(expr, tir.Var):
+                tir_vars.add(expr)
 
         for expr in buffer_region.region:
             assert expr.extent == 1
@@ -68,11 +69,9 @@ def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
 
 
 class DecodeGEMV(ScheduleRule):
-    def __init__(self) -> None:
-        super().__init__()
-        self.get_loop_iter_type = 
get_global_func("tir.schedule.GetLoopIterType")
+    """A rule for DecodeGEMV."""
 
-    def apply(  # pylint: disable=too-many-locals
+    def apply(  # pylint: 
disable=too-many-locals,too-many-branches,too-many-return-statements
         self,
         func: tir.PrimFunc,
         target: Target,
@@ -80,15 +79,8 @@ class DecodeGEMV(ScheduleRule):
     ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
         if not isinstance(func, tir.PrimFunc):
             return None
-
-        if target.kind.name == "cuda":
-            len_tx, len_ty = 16, 16
-        else:
-            len_tx, len_ty = 8, 8
-
         sch = tir.Schedule(func)
         block_infos = try_inline_contiguous_spatial(sch, 
normalize_prim_func(sch))
-
         if block_infos is None or len(block_infos) > 2:
             return None
 
@@ -97,96 +89,145 @@ class DecodeGEMV(ScheduleRule):
         block_stmt = sch.get(block)
 
         # Step 1. Check reduction block
-        if not block_info.is_reduction():
+        if (
+            (not block_info.is_reduction())
+            or len(block_stmt.writes) != 1
+            or _get_reduction_expr(block_stmt) is None
+        ):
             return None
-        if len(block_stmt.writes) != 1:
-            return None
-        if _get_reduction_expr(block_stmt) is None:
-            return None
-
-        # Step 2. Sort out the spatial and reduction loops
-        sorted_iter_access = normalize_to_iter_sum(
-            _detect_dominant_read(block_stmt),
-            input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+        # Step 2. Normalize the block, merge spatial and reduction iters
+        is_inner_reduction, c_factor = self._normalize(
+            sch,
+            block_info,
+            arith.normalize_to_iter_sum(
+                _detect_dominant_read(block_stmt),
+                input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+            ),
         )
-        if sorted_iter_access.base != 0:
-            return None
-        iter_to_info = {i.var: i for i in block_info.iters}
-        s_loops, r_loops, c_loops = [], [], []
-        for split in sorted_iter_access.args:
-            block_var = split.source.source
-            block_var_info = iter_to_info[block_var]
-            loop_rv = block_var_info.loop_rv
-            is_inner_reduction = block_var_info.kind == "R"
-            if split.lower_factor > 1:
-                c_loop_factor = split.lower_factor
-                loop_rv, c_loop = sch.split(loop_rv, factors=[None, 
c_loop_factor])
-                c_loops.append(c_loop)
-                is_loop_c_reduction = is_inner_reduction
-            if is_inner_reduction:
-                r_loops.append(loop_rv)
-            else:
-                s_loops.append(loop_rv)
-
-        if len(c_loops) > 1:
-            return None
-        if len(s_loops) != len([_ for i in block_info.iters if i.kind == "S"]):
+        if is_inner_reduction is None and c_factor is None:
             return None
-        if len(s_loops) == 0 or len(r_loops) == 0:
-            return None
-
-        sch.reorder(*s_loops, *r_loops, *c_loops)
-        s = sch.fuse(*s_loops)
-        r = sch.fuse(*r_loops)
-
-        if is_inner_reduction:
-            _, tx = sch.split(r, factors=[None, len_tx * len_ty])
-            rf = sch.rfactor(tx, 0)
-            s, r, tx = sch.get_loops(rf)[:3]
-            sch.reorder(s, tx, r)
-            sch.reverse_compute_at(block, s, preserve_unit_loops=True)
-            sch.bind(tx, "threadIdx.x")
-            sch.bind(s, "blockIdx.x")
-        else:
-            sch.split(s, factors=[None, len_tx])
-            _, ty = sch.split(r, factors=[None, len_ty])
-            rf = sch.rfactor(ty, 0)
-            bx, tx, r, ty = sch.get_loops(rf)[:4]
-            sch.reorder(bx, tx, ty, r)
-            sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
-            sch.bind(tx, "threadIdx.x")
-            sch.bind(ty, "threadIdx.y")
-            sch.bind(bx, "blockIdx.x")
-
-        s_loops, r_loops = [], []
-        for loop_rv in sch.get_loops(block)[1:]:
-            iter_type = self.get_loop_iter_type(sch, loop_rv)
-            if iter_type == "S":
-                s_loops.append(loop_rv)
-            elif iter_type == "R":
-                r_loops.append(loop_rv)
-            else:
-                raise RuntimeError("Unknown loop type " + str(iter_type))
-        sch.reorder(*s_loops, *r_loops)
-        s_ctr = sch.fuse(*s_loops)
-        r_ctr = sch.fuse(*r_loops)
-
-        if c_loops and not is_loop_c_reduction:
-            s_ctr, inner = sch.split(s_ctr, factors=[None, c_loop_factor])
-            sch.reorder(s_ctr, r_ctr, inner)
-
+        # Step 3. Do the scheduling
         if is_inner_reduction:
-            sch.bind(r_ctr, "threadIdx.x")
-            sch.set_scope(rf, 0, "local")
-            sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+            self._sch_inner_reduction(sch, target, block, c_factor)
         else:
-            sch.bind(s_ctr, "threadIdx.x")
-            sch.bind(r_ctr, "threadIdx.y")
-            sch.set_scope(rf, 0, "local")
-            sch.decompose_reduction(rf, sch.get_loops(rf)[3])
-
+            self._sch_inner_spatial(sch, target, block, c_factor)
+        # Step 4. Schedule epilogue
         if len(block_infos) == 2:
             sch.set_scope(block, 0, "local")
             sch.reverse_compute_at(block_infos[1].block_rv, 
sch.get_loops(block)[0])
-
         return sch
+
+    def _normalize(
+        self,
+        sch: tir.Schedule,
+        block_info: BlockInfo,
+        iter_sum: arith.IterSumExpr,
+    ) -> Tuple[Optional[bool], Optional[int]]:
+        if iter_sum.base != 0:
+            return None, None
+        iter_to_info = {i.var: i for i in block_info.iters}
+        s_dom, r_dom, c_dom, c_factor = None, None, None, None
+        for split in iter_sum.args:
+            var = split.source.source
+            info = iter_to_info[var]
+            dom = info.dom
+            is_inner_reduction = info.kind == "R"
+            if split.lower_factor > 1:
+                if c_dom is not None:
+                    return None, None
+                c_dom = tir.floormod(var, split.lower_factor)
+                var = tir.floordiv(var, split.lower_factor)
+                dom = tir.floordiv(dom, split.lower_factor)
+                if not is_inner_reduction:
+                    c_factor = split.lower_factor
+            if is_inner_reduction:
+                if r_dom is None:
+                    r_dom = var
+                else:
+                    r_dom = r_dom * dom + var
+            else:
+                if s_dom is None:
+                    s_dom = var
+                else:
+                    s_dom = s_dom * dom + var
+
+        assert r_dom is not None
+        if s_dom is None:
+            s_dom = tir.const(1, r_dom.dtype)
+        if c_dom is None:
+            c_dom = tir.const(1, r_dom.dtype)
+        sch.transform_block_layout(
+            block_info.block_rv,
+            tir.IndexMap(
+                [i.var for i in block_info.iters],
+                [s_dom, r_dom, c_dom],
+                None,
+            ),
+        )
+        return is_inner_reduction, c_factor
+
+    def _sch_inner_reduction(
+        self,
+        sch: tir.Schedule,
+        target: Target,
+        block: tir.schedule.BlockRV,
+        unroll_spatial_factor: Optional[int],
+    ):
+        # pylint: disable=invalid-name
+        _, r, _ = sch.get_loops(block)
+        (len_tx,) = utils.suggest_threads_per_block(  # pylint: 
disable=unbalanced-tuple-unpacking
+            target, [sch.get(r)]
+        )
+
+        _, tx = sch.split(r, factors=[None, len_tx])
+        # Schedule the RF block
+        rf = sch.rfactor(tx, 0)
+        bx, r, tx, _ = sch.get_loops(rf)
+        sch.reorder(bx, tx, r)
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(tx, "threadIdx.x")
+        sch.set_scope(rf, 0, "local")
+        sch.decompose_reduction(rf, r)
+        # Schedule the write back block
+        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+        _, tx, *s = sch.get_loops(block)
+        s = sch.fuse(*s)
+        sch.reorder(s, tx)
+        if unroll_spatial_factor:
+            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
+            sch.reorder(s, tx, inner)
+        sch.bind(tx, "threadIdx.x")
+        # pylint: enable=invalid-name
+
+    def _sch_inner_spatial(
+        self,
+        sch: tir.Schedule,
+        _: Target,
+        block: tir.schedule.BlockRV,
+        unroll_spatial_factor: Optional[int],
+    ):
+        # pylint: disable=invalid-name
+        s, r, _ = sch.get_loops(block)
+        len_tx, len_ty = 16, 16
+        _, _ = sch.split(s, factors=[None, len_tx])
+        _, ty = sch.split(r, factors=[None, len_ty])
+        # Schedule the RF block
+        rf = sch.rfactor(ty, 0)
+        bx, tx, r, ty, _ = sch.get_loops(rf)
+        sch.reorder(bx, tx, ty, r)
+        sch.bind(tx, "threadIdx.x")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(bx, "blockIdx.x")
+        sch.set_scope(rf, 0, "local")
+        sch.decompose_reduction(rf, r)
+        # Schedule the write back block
+        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+        _, r, *s = sch.get_loops(block)
+        s = sch.fuse(*s)
+        sch.reorder(s, r)
+        if unroll_spatial_factor:
+            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
+            sch.reorder(s, r, inner)
+        sch.bind(s, "threadIdx.x")
+        sch.bind(r, "threadIdx.y")
+        # pylint: enable=invalid-name
diff --git a/python/tvm/dlight/gpu/fallback.py 
b/python/tvm/dlight/gpu/fallback.py
index 6b120b1648..14b74887af 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -21,7 +21,8 @@ from typing import List
 from tvm import tir
 from tvm.target import Target
 
-from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline
+from ..base import ScheduleRule, normalize_prim_func, try_inline
+from . import utils
 
 
 class Fallback(ScheduleRule):
@@ -36,7 +37,7 @@ class Fallback(ScheduleRule):
         target: Target,
         _: bool,
     ) -> tir.Schedule:
-        max_threads_per_block = analysis.get_max_threads_per_block(target)
+        max_threads_per_block = utils.max_threads_per_block(target)
 
         sch = tir.Schedule(func)
         block_infos = try_inline(sch, normalize_prim_func(sch))
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index e66eaa3222..86d685e53c 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -16,14 +16,14 @@
 # under the License.
 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
-from enum import Enum
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple
 
 from tvm import tir
 from tvm.ir import Range
 from tvm.target import Target
-from tvm.tir import PrimExpr, Var, IterVar
+from tvm.tir import IterVar, PrimExpr, Var
 from tvm.tir.analysis import undefined_vars
 from tvm.tir.schedule.schedule import BlockRV
 
diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py
new file mode 100644
index 0000000000..4fcc762942
--- /dev/null
+++ b/python/tvm/dlight/gpu/utils.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Utility methods for generic GPU."""
+from typing import List, Optional
+
+from tvm import tir
+from tvm.target import Target
+
+
+def max_threads_per_block(target: Target) -> int:
+    """Get the maximum number of threads per block for a given target.
+
+    Parameters
+    ----------
+    target : Target
+        The target to get the maximum number of threads per block for.
+
+    Returns
+    -------
+    max_threads_per_block : int
+        The maximum number of threads per block for the given target.
+    """
+    for name in ["max_threads_per_block", "max_num_threads"]:
+        result = target.attrs.get(name, None)
+        if result is not None:
+            return result
+    if target.kind.name == "cuda":
+        return 1024
+    return 256
+
+
+def suggest_threads_per_block(
+    target: Target,
+    loops: List[tir.For],
+    max_threads_for_dynamic_loop: int = 32,
+) -> List[int]:
+    if target.kind.name == "cuda":
+        threads = 256
+    else:
+        threads = 64
+    results: List[Optional[int]] = []
+    dynamic: List[int] = []
+    for i, loop in enumerate(loops):
+        loop_extent = loop.extent
+        if isinstance(loop_extent, tir.IntImm):
+            loop_extent = loop_extent.value
+            extent = 1
+            while extent <= loop_extent and extent <= threads:
+                extent *= 2
+            extent //= 2
+            assert extent >= 1
+            assert threads % extent == 0
+            threads //= extent
+            results.append(extent)
+        else:
+            results.append(None)
+            dynamic.append(i)
+
+    for i in dynamic:
+        extent = 1
+        while extent <= max_threads_for_dynamic_loop and extent <= threads:
+            extent *= 2
+        extent //= 2
+        assert extent >= 1
+        assert threads % extent == 0
+        threads //= extent
+        results[i] = extent
+
+    if dynamic:
+        results[dynamic[0]] *= threads
+
+    return results
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py 
b/tests/python/dlight/test_gpu_decode_gemv.py
index 46232a461e..303b16809e 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=missing-docstring
-import tvm
+# pylint: 
disable=missing-docstring,line-too-long,invalid-name,too-few-public-methods,too-many-locals
 from tvm import dlight as dl
 from tvm.ir import assert_structural_equal
 from tvm.script import ir as I
@@ -31,7 +30,6 @@ def test_decode_gemv_1():
         @T.prim_func
         def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
             T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
-            # with T.block("root"):
             B = T.alloc_buffer((4096, 4096), "float16")
             for i, j in T.grid(4096, 4096):
                 with T.block("decode"):
@@ -66,8 +64,8 @@ def test_decode_gemv_1():
                         with T.block("matmul_rf_update"):
                             vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
                             v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR", 
[i2_i0_i1_fused, k_0_fused_0, k_1])
-                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + 
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) % 
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
-                for ax1_ax2_ax3_fused in range(1):
+                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 256 + 
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, vk_0_fused_0 * 256 + vk_0_fused_1], 
T.Cast("uint32", ((vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 256 + 
vk_0_fused_1) // 4])
+                for ax1_ax2_ax3_fused in range(1): # pylint: 
disable=unused-variable
                     for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
                         with T.block("matmul"):
                             vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
@@ -128,7 +126,7 @@ def test_decode_gemv_2():
                                 vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1)
                                 v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 
16 + i2_i0_i1_fused_1)
                                 vk_0_fused_0, vk_1 = T.axis.remap("RR", 
[k_0_fused_0, k_1])
-                                C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 128 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 
8, v_i2], T.Cast("uint32", (vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) % 8) 
* T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 128 + 
vk_0_fused_1 * 8 + vk_1) // 32, v_i2])
+                                C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 16 + 
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[vk_0_fused_0 * 16 + vk_0_fused_1, v_i2], 
T.Cast("uint32", ((vk_0_fused_0 * 16 + vk_0_fused_1) * 8 + vk_1) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 16 + 
vk_0_fused_1) // 4, v_i2])
                 for ax1_ax2_ax3_fused in T.thread_binding(16, 
thread="threadIdx.x"):
                     for ax0_fused in T.thread_binding(16, 
thread="threadIdx.y"):
                         with T.block("matmul"):
@@ -184,23 +182,26 @@ def test_decode_gemv_3():
                     for i2_1_init in range(8):
                         with T.block("matmul_rf_init"):
                             vk_fused_1 = T.axis.spatial(256, k_fused_1)
-                            v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + 
i2_1_init)
-                            C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0)
+                            v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
+                            v_i2 = T.axis.spatial(8, i2_1_init)
+                            C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] = 
T.float16(0)
                     for k_fused_0, i2_1 in T.grid(16, 8):
                         with T.block("matmul_rf_update"):
                             vk_fused_1 = T.axis.spatial(256, k_fused_1)
-                            v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + 
i2_1)
+                            v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
                             vk_fused_0 = T.axis.reduce(16, k_fused_0)
-                            C_rf_local[vk_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 256 + vk_fused_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2 // 8, vk_fused_0 * 256 + 
vk_fused_1], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[v_i2 // 32, vk_fused_0 * 256 + vk_fused_1])
+                            v_i2 = T.axis.spatial(8, i2_1)
+                            C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] + V[0, 0, vk_fused_0 * 256 + 
vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i1, 
vk_fused_0 * 256 + vk_fused_1], T.Cast("uint32", (v_i1 * 8 + v_i2) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i1 // 4, vk_fused_0 * 256 + 
vk_fused_1])
                 for ax1_ax2_ax3_fused_0 in range(1):
                     for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
                         for ax1_ax2_ax3_fused_1 in range(8):
                             with T.block("matmul"):
                                 vk_fused_1 = T.axis.reduce(256, ax0_fused)
-                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 
8 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
+                                v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
+                                v_i2 = T.axis.spatial(8, ax1_ax2_ax3_fused_0 * 
8 + ax1_ax2_ax3_fused_1)
                                 with T.init():
-                                    C[0, 0, v_i2] = T.float16(0)
-                                C[0, 0, v_i2] = C[0, 0, v_i2] + 
C_rf_local[vk_fused_1, 0, 0, v_i2]
+                                    C[0, 0, v_i1 * 8 + v_i2] = T.float16(0)
+                                C[0, 0, v_i1 * 8 + v_i2] = C[0, 0, v_i1 * 8 + 
v_i2] + C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2]
 
     # fmt: on
 
@@ -241,7 +242,6 @@ def test_decode_gemv_4():
         @T.prim_func
         def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
             T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
-            # with T.block("root"):
             C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", 
scope="local")
             for i2_0_i0_i1_fused_0 in T.thread_binding(32, 
thread="blockIdx.x"):
                 for i2_0_i0_i1_fused_1 in T.thread_binding(16, 
thread="threadIdx.x"):
@@ -249,23 +249,26 @@ def test_decode_gemv_4():
                         for i2_1_init in range(8):
                             with T.block("matmul_rf_init"):
                                 vk_fused_1 = T.axis.spatial(16, k_fused_1)
-                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init)
-                                C_rf_local[vk_fused_1, 0, 0, v_i2] = 
T.float16(0)
+                                v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 
16 + i2_0_i0_i1_fused_1)
+                                v2 = T.axis.spatial(8, i2_1_init)
+                                C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] = 
T.float16(0)
                         for k_fused_0, i2_1 in T.grid(256, 8):
                             with T.block("matmul_rf_update"):
                                 vk_fused_1 = T.axis.spatial(16, k_fused_1)
-                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1)
+                                v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 
16 + i2_0_i0_i1_fused_1)
                                 vk_fused_0 = T.axis.reduce(256, k_fused_0)
-                                C_rf_local[vk_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1, 
v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32])
+                                v2 = T.axis.spatial(8, i2_1)
+                                C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] = 
C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] + V[0, 0, vk_fused_0 * 16 + 
vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 
+ vk_fused_1, v1], T.Cast("uint32", (v1 * 8 + v2) % 8) * T.uint32(4)), 
T.uint32(15))) - T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v1 // 4])
                 for ax1_ax2_ax3_fused_0 in T.thread_binding(16, 
thread="threadIdx.x"):
                     for ax0_fused in T.thread_binding(16, 
thread="threadIdx.y"):
                         for ax1_ax2_ax3_fused_1 in range(8):
                             with T.block("matmul"):
                                 vk_fused_1 = T.axis.reduce(16, ax0_fused)
-                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
+                                v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 
16 + (ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) // 8)
+                                v2 = T.axis.spatial(8, (ax1_ax2_ax3_fused_0 * 
8 + ax1_ax2_ax3_fused_1) % 8)
                                 with T.init():
-                                    C[0, 0, v_i2] = T.float16(0)
-                                C[0, 0, v_i2] = C[0, 0, v_i2] + 
C_rf_local[vk_fused_1, 0, 0, v_i2]
+                                    C[0, 0, v1 * 8 + v2] = T.float16(0)
+                                C[0, 0, v1 * 8 + v2] = C[0, 0, v1 * 8 + v2] + 
C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2]
 
     # fmt: on
 
@@ -325,8 +328,8 @@ def test_decode_gemv_sigmoid():
                         with T.block("matmul_rf_update"):
                             vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
                             v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR", 
[i2_i0_i1_fused, k_0_fused_0, k_1])
-                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + 
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) % 
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
-                for ax1_ax2_ax3_fused in range(1):
+                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 256 + 
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, vk_0_fused_0 * 256 + vk_0_fused_1], 
T.Cast("uint32", ((vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 256 + 
vk_0_fused_1) // 4])
+                for ax1_ax2_ax3_fused in range(1): # pylint: 
disable=unused-variable
                     for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
                         with T.block("matmul"):
                             vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
@@ -397,9 +400,7 @@ def test_decode_gemv_1_fp32():
                     for ax1_0_fused_0, ax1_1 in T.grid(2, 8):
                         with T.block("matmul_rf_update"):
                             vax1_0_fused_1, v0, vax1_0_fused_0, vax1_1 = 
T.axis.remap("SSRR", [ax1_0_fused_1, ax0_fused, ax1_0_fused_0, ax1_1])
-                            T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0], 
V[0, 0, vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1], W[v0, 
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0, 
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 32])
-                            T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
-                            C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = 
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, 
vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1]) * T.Cast("float32", 
(T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 2048 + 
vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 2048 + 
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[v0, (vax1_0_fused_0 * 2048 + vax1_0 [...]
+                            C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = 
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, 
(vax1_0_fused_0 * 256 + vax1_0_fused_1) * 8 + vax1_1]) * T.Cast("float32", 
(T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, vax1_0_fused_0 * 256 + 
vax1_0_fused_1], T.Cast("uint32", ((vax1_0_fused_0 * 256 + vax1_0_fused_1) * 8 
+ vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, 
(vax1_0_fused_0 * 256 + vax1_0_fused_1) // 4])
                 for ax1_fused in range(1):
                     for ax0_fused_1 in T.thread_binding(256, 
thread="threadIdx.x"):
                         with T.block("matmul"):

Reply via email to