This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 04f22a9257 [Dlight] Enhance Decode-GEMV Schedule (#15195)
04f22a9257 is described below
commit 04f22a92575b659b17d318470a00603d7faf9714
Author: Junru Shao <[email protected]>
AuthorDate: Tue Jul 4 21:47:32 2023 -0700
[Dlight] Enhance Decode-GEMV Schedule (#15195)
[Dlight] Enhance Decode-GEMV Rules
This PR enhances Decode-GEMV rule with the following changes:
- Normalize the GEMV iter domain to S-R-C via transform-block-layout.
This would help with further analysis and scheduling, in cases for
example, when there was no spatial loop in the original reduction
block.
- Get rid of the ad hoc iter type analysis, including the logic calling
into a TVM packed func `tir.schedule.GetLoopIterType` using
`tvm._ffi.get_global_func`.
- Split out the logic for two separate cases of scheduling, where the
innermost dimension is spatial or reduction.
- Introduces `suggest_threads_per_block` to guess the threads to be
allocated to each threadblock. This helps avoid the previous case where
dlight allocates 256 threads for a workload whose degree of parallelism
is only 128.
- Misc improvements.
The rest of the changes are split out into separate PRs that are already
merged to main.
---
pyproject.toml | 4 +
python/tvm/dlight/base/analysis.py | 3 +-
python/tvm/dlight/gpu/__init__.py | 4 +-
python/tvm/dlight/gpu/decode_gemv.py | 257 ++++++++++++++++------------
python/tvm/dlight/gpu/fallback.py | 5 +-
python/tvm/dlight/gpu/matmul.py | 4 +-
python/tvm/dlight/gpu/utils.py | 87 ++++++++++
tests/python/dlight/test_gpu_decode_gemv.py | 53 +++---
8 files changed, 275 insertions(+), 142 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 5cca711ddb..e984b41b11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+[tool.isort]
+profile = "black"
+src_paths = ["python", "tests/python"]
+
[tool.black]
line-length = 100
diff --git a/python/tvm/dlight/base/analysis.py
b/python/tvm/dlight/base/analysis.py
index d11e29a8ad..2607968ef2 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -17,13 +17,12 @@
"""Analysis on TIR blocks, loops and functions."""
from typing import List, Optional, Union
-from typing_extensions import Literal
-
from tvm import tir
from tvm._ffi import get_global_func
from tvm.target.target import Target
from tvm.tir import Schedule
from tvm.tir.schedule import BlockRV
+from typing_extensions import Literal
class IterInfo:
diff --git a/python/tvm/dlight/gpu/__init__.py
b/python/tvm/dlight/gpu/__init__.py
index 79090d400b..934928ffaf 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -18,7 +18,7 @@
GPU-generic schedule rules.
For CUDA/ROCm/Vulkan/Metal-specific rules, use
`tvm.dlight.cuda/rocm/vulkan/metal` instead
"""
-from .fallback import Fallback
from .decode_gemv import DecodeGEMV
-from .reduction import Reduction
+from .fallback import Fallback
from .matmul import Matmul
+from .reduction import Reduction
diff --git a/python/tvm/dlight/gpu/decode_gemv.py
b/python/tvm/dlight/gpu/decode_gemv.py
index 18395b8063..b9e8b44ef2 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -14,19 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=missing-docstring
-"""A fallback schedule rule for GPU operators."""
-# pylint: disable=invalid-name
+"""A rule for DecodeGEMV."""
+from typing import List, Optional, Set, Tuple, Union
-from typing import List, Optional, Union
-
-from tvm import tir
-from tvm._ffi import get_global_func
-from tvm.arith import normalize_to_iter_sum
+from tvm import arith, tir
from tvm.ir import structural_equal
from tvm.target import Target
-from ..base import ScheduleRule, normalize_prim_func,
try_inline_contiguous_spatial
+from ..base import (
+ BlockInfo,
+ ScheduleRule,
+ normalize_prim_func,
+ try_inline_contiguous_spatial,
+)
+from . import utils
def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
@@ -47,13 +48,13 @@ def _get_reduction_expr(block: tir.Block) ->
Optional[tir.PrimExpr]:
def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
dominant_read, read_iters = None, None
- tir_vars = set()
+ tir_vars: Set[tir.Var] = set()
for buffer_region in block.reads:
tir_vars.clear()
- def _collect_tir_var(e):
- if isinstance(e, tir.Var):
- tir_vars.add(e)
+ def _collect_tir_var(expr):
+ if isinstance(expr, tir.Var):
+ tir_vars.add(expr)
for expr in buffer_region.region:
assert expr.extent == 1
@@ -68,11 +69,9 @@ def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
class DecodeGEMV(ScheduleRule):
- def __init__(self) -> None:
- super().__init__()
- self.get_loop_iter_type =
get_global_func("tir.schedule.GetLoopIterType")
+ """A rule for DecodeGEMV."""
- def apply( # pylint: disable=too-many-locals
+ def apply( # pylint:
disable=too-many-locals,too-many-branches,too-many-return-statements
self,
func: tir.PrimFunc,
target: Target,
@@ -80,15 +79,8 @@ class DecodeGEMV(ScheduleRule):
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
if not isinstance(func, tir.PrimFunc):
return None
-
- if target.kind.name == "cuda":
- len_tx, len_ty = 16, 16
- else:
- len_tx, len_ty = 8, 8
-
sch = tir.Schedule(func)
block_infos = try_inline_contiguous_spatial(sch,
normalize_prim_func(sch))
-
if block_infos is None or len(block_infos) > 2:
return None
@@ -97,96 +89,145 @@ class DecodeGEMV(ScheduleRule):
block_stmt = sch.get(block)
# Step 1. Check reduction block
- if not block_info.is_reduction():
+ if (
+ (not block_info.is_reduction())
+ or len(block_stmt.writes) != 1
+ or _get_reduction_expr(block_stmt) is None
+ ):
return None
- if len(block_stmt.writes) != 1:
- return None
- if _get_reduction_expr(block_stmt) is None:
- return None
-
- # Step 2. Sort out the spatial and reduction loops
- sorted_iter_access = normalize_to_iter_sum(
- _detect_dominant_read(block_stmt),
- input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+ # Step 2. Normalize the block, merge spatial and reduction iters
+ is_inner_reduction, c_factor = self._normalize(
+ sch,
+ block_info,
+ arith.normalize_to_iter_sum(
+ _detect_dominant_read(block_stmt),
+ input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+ ),
)
- if sorted_iter_access.base != 0:
- return None
- iter_to_info = {i.var: i for i in block_info.iters}
- s_loops, r_loops, c_loops = [], [], []
- for split in sorted_iter_access.args:
- block_var = split.source.source
- block_var_info = iter_to_info[block_var]
- loop_rv = block_var_info.loop_rv
- is_inner_reduction = block_var_info.kind == "R"
- if split.lower_factor > 1:
- c_loop_factor = split.lower_factor
- loop_rv, c_loop = sch.split(loop_rv, factors=[None,
c_loop_factor])
- c_loops.append(c_loop)
- is_loop_c_reduction = is_inner_reduction
- if is_inner_reduction:
- r_loops.append(loop_rv)
- else:
- s_loops.append(loop_rv)
-
- if len(c_loops) > 1:
- return None
- if len(s_loops) != len([_ for i in block_info.iters if i.kind == "S"]):
+ if is_inner_reduction is None and c_factor is None:
return None
- if len(s_loops) == 0 or len(r_loops) == 0:
- return None
-
- sch.reorder(*s_loops, *r_loops, *c_loops)
- s = sch.fuse(*s_loops)
- r = sch.fuse(*r_loops)
-
- if is_inner_reduction:
- _, tx = sch.split(r, factors=[None, len_tx * len_ty])
- rf = sch.rfactor(tx, 0)
- s, r, tx = sch.get_loops(rf)[:3]
- sch.reorder(s, tx, r)
- sch.reverse_compute_at(block, s, preserve_unit_loops=True)
- sch.bind(tx, "threadIdx.x")
- sch.bind(s, "blockIdx.x")
- else:
- sch.split(s, factors=[None, len_tx])
- _, ty = sch.split(r, factors=[None, len_ty])
- rf = sch.rfactor(ty, 0)
- bx, tx, r, ty = sch.get_loops(rf)[:4]
- sch.reorder(bx, tx, ty, r)
- sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
- sch.bind(tx, "threadIdx.x")
- sch.bind(ty, "threadIdx.y")
- sch.bind(bx, "blockIdx.x")
-
- s_loops, r_loops = [], []
- for loop_rv in sch.get_loops(block)[1:]:
- iter_type = self.get_loop_iter_type(sch, loop_rv)
- if iter_type == "S":
- s_loops.append(loop_rv)
- elif iter_type == "R":
- r_loops.append(loop_rv)
- else:
- raise RuntimeError("Unknown loop type " + str(iter_type))
- sch.reorder(*s_loops, *r_loops)
- s_ctr = sch.fuse(*s_loops)
- r_ctr = sch.fuse(*r_loops)
-
- if c_loops and not is_loop_c_reduction:
- s_ctr, inner = sch.split(s_ctr, factors=[None, c_loop_factor])
- sch.reorder(s_ctr, r_ctr, inner)
-
+ # Step 3. Do the scheduling
if is_inner_reduction:
- sch.bind(r_ctr, "threadIdx.x")
- sch.set_scope(rf, 0, "local")
- sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+ self._sch_inner_reduction(sch, target, block, c_factor)
else:
- sch.bind(s_ctr, "threadIdx.x")
- sch.bind(r_ctr, "threadIdx.y")
- sch.set_scope(rf, 0, "local")
- sch.decompose_reduction(rf, sch.get_loops(rf)[3])
-
+ self._sch_inner_spatial(sch, target, block, c_factor)
+ # Step 4. Schedule epilogue
if len(block_infos) == 2:
sch.set_scope(block, 0, "local")
sch.reverse_compute_at(block_infos[1].block_rv,
sch.get_loops(block)[0])
-
return sch
+
+ def _normalize(
+ self,
+ sch: tir.Schedule,
+ block_info: BlockInfo,
+ iter_sum: arith.IterSumExpr,
+ ) -> Tuple[Optional[bool], Optional[int]]:
+ if iter_sum.base != 0:
+ return None, None
+ iter_to_info = {i.var: i for i in block_info.iters}
+ s_dom, r_dom, c_dom, c_factor = None, None, None, None
+ for split in iter_sum.args:
+ var = split.source.source
+ info = iter_to_info[var]
+ dom = info.dom
+ is_inner_reduction = info.kind == "R"
+ if split.lower_factor > 1:
+ if c_dom is not None:
+ return None, None
+ c_dom = tir.floormod(var, split.lower_factor)
+ var = tir.floordiv(var, split.lower_factor)
+ dom = tir.floordiv(dom, split.lower_factor)
+ if not is_inner_reduction:
+ c_factor = split.lower_factor
+ if is_inner_reduction:
+ if r_dom is None:
+ r_dom = var
+ else:
+ r_dom = r_dom * dom + var
+ else:
+ if s_dom is None:
+ s_dom = var
+ else:
+ s_dom = s_dom * dom + var
+
+ assert r_dom is not None
+ if s_dom is None:
+ s_dom = tir.const(1, r_dom.dtype)
+ if c_dom is None:
+ c_dom = tir.const(1, r_dom.dtype)
+ sch.transform_block_layout(
+ block_info.block_rv,
+ tir.IndexMap(
+ [i.var for i in block_info.iters],
+ [s_dom, r_dom, c_dom],
+ None,
+ ),
+ )
+ return is_inner_reduction, c_factor
+
+ def _sch_inner_reduction(
+ self,
+ sch: tir.Schedule,
+ target: Target,
+ block: tir.schedule.BlockRV,
+ unroll_spatial_factor: Optional[int],
+ ):
+ # pylint: disable=invalid-name
+ _, r, _ = sch.get_loops(block)
+ (len_tx,) = utils.suggest_threads_per_block( # pylint:
disable=unbalanced-tuple-unpacking
+ target, [sch.get(r)]
+ )
+
+ _, tx = sch.split(r, factors=[None, len_tx])
+ # Schedule the RF block
+ rf = sch.rfactor(tx, 0)
+ bx, r, tx, _ = sch.get_loops(rf)
+ sch.reorder(bx, tx, r)
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(tx, "threadIdx.x")
+ sch.set_scope(rf, 0, "local")
+ sch.decompose_reduction(rf, r)
+ # Schedule the write back block
+ sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+ _, tx, *s = sch.get_loops(block)
+ s = sch.fuse(*s)
+ sch.reorder(s, tx)
+ if unroll_spatial_factor:
+ s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
+ sch.reorder(s, tx, inner)
+ sch.bind(tx, "threadIdx.x")
+ # pylint: enable=invalid-name
+
+ def _sch_inner_spatial(
+ self,
+ sch: tir.Schedule,
+ _: Target,
+ block: tir.schedule.BlockRV,
+ unroll_spatial_factor: Optional[int],
+ ):
+ # pylint: disable=invalid-name
+ s, r, _ = sch.get_loops(block)
+ len_tx, len_ty = 16, 16
+ _, _ = sch.split(s, factors=[None, len_tx])
+ _, ty = sch.split(r, factors=[None, len_ty])
+ # Schedule the RF block
+ rf = sch.rfactor(ty, 0)
+ bx, tx, r, ty, _ = sch.get_loops(rf)
+ sch.reorder(bx, tx, ty, r)
+ sch.bind(tx, "threadIdx.x")
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(bx, "blockIdx.x")
+ sch.set_scope(rf, 0, "local")
+ sch.decompose_reduction(rf, r)
+ # Schedule the write back block
+ sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+ _, r, *s = sch.get_loops(block)
+ s = sch.fuse(*s)
+ sch.reorder(s, r)
+ if unroll_spatial_factor:
+ s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
+ sch.reorder(s, r, inner)
+ sch.bind(s, "threadIdx.x")
+ sch.bind(r, "threadIdx.y")
+ # pylint: enable=invalid-name
diff --git a/python/tvm/dlight/gpu/fallback.py
b/python/tvm/dlight/gpu/fallback.py
index 6b120b1648..14b74887af 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -21,7 +21,8 @@ from typing import List
from tvm import tir
from tvm.target import Target
-from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline
+from ..base import ScheduleRule, normalize_prim_func, try_inline
+from . import utils
class Fallback(ScheduleRule):
@@ -36,7 +37,7 @@ class Fallback(ScheduleRule):
target: Target,
_: bool,
) -> tir.Schedule:
- max_threads_per_block = analysis.get_max_threads_per_block(target)
+ max_threads_per_block = utils.max_threads_per_block(target)
sch = tir.Schedule(func)
block_infos = try_inline(sch, normalize_prim_func(sch))
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index e66eaa3222..86d685e53c 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -16,14 +16,14 @@
# under the License.
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
-from enum import Enum
from dataclasses import dataclass
+from enum import Enum
from typing import Dict, List, Optional, Set, Tuple
from tvm import tir
from tvm.ir import Range
from tvm.target import Target
-from tvm.tir import PrimExpr, Var, IterVar
+from tvm.tir import IterVar, PrimExpr, Var
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py
new file mode 100644
index 0000000000..4fcc762942
--- /dev/null
+++ b/python/tvm/dlight/gpu/utils.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Utility methods for generic GPU."""
+from typing import List, Optional
+
+from tvm import tir
+from tvm.target import Target
+
+
+def max_threads_per_block(target: Target) -> int:
+ """Get the maximum number of threads per block for a given target.
+
+ Parameters
+ ----------
+ target : Target
+ The target to get the maximum number of threads per block for.
+
+ Returns
+ -------
+ max_threads_per_block : int
+ The maximum number of threads per block for the given target.
+ """
+ for name in ["max_threads_per_block", "max_num_threads"]:
+ result = target.attrs.get(name, None)
+ if result is not None:
+ return result
+ if target.kind.name == "cuda":
+ return 1024
+ return 256
+
+
+def suggest_threads_per_block(
+ target: Target,
+ loops: List[tir.For],
+ max_threads_for_dynamic_loop: int = 32,
+) -> List[int]:
+ if target.kind.name == "cuda":
+ threads = 256
+ else:
+ threads = 64
+ results: List[Optional[int]] = []
+ dynamic: List[int] = []
+ for i, loop in enumerate(loops):
+ loop_extent = loop.extent
+ if isinstance(loop_extent, tir.IntImm):
+ loop_extent = loop_extent.value
+ extent = 1
+ while extent <= loop_extent and extent <= threads:
+ extent *= 2
+ extent //= 2
+ assert extent >= 1
+ assert threads % extent == 0
+ threads //= extent
+ results.append(extent)
+ else:
+ results.append(None)
+ dynamic.append(i)
+
+ for i in dynamic:
+ extent = 1
+ while extent <= max_threads_for_dynamic_loop and extent <= threads:
+ extent *= 2
+ extent //= 2
+ assert extent >= 1
+ assert threads % extent == 0
+ threads //= extent
+ results[i] = extent
+
+ if dynamic:
+ results[dynamic[0]] *= threads
+
+ return results
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py
b/tests/python/dlight/test_gpu_decode_gemv.py
index 46232a461e..303b16809e 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=missing-docstring
-import tvm
+# pylint:
disable=missing-docstring,line-too-long,invalid-name,too-few-public-methods,too-many-locals
from tvm import dlight as dl
from tvm.ir import assert_structural_equal
from tvm.script import ir as I
@@ -31,7 +30,6 @@ def test_decode_gemv_1():
@T.prim_func
def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
- # with T.block("root"):
B = T.alloc_buffer((4096, 4096), "float16")
for i, j in T.grid(4096, 4096):
with T.block("decode"):
@@ -66,8 +64,8 @@ def test_decode_gemv_1():
with T.block("matmul_rf_update"):
vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR",
[i2_i0_i1_fused, k_0_fused_0, k_1])
- C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 +
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 +
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) %
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 *
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
- for ax1_ax2_ax3_fused in range(1):
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 256 +
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i2, vk_0_fused_0 * 256 + vk_0_fused_1],
T.Cast("uint32", ((vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1) % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 256 +
vk_0_fused_1) // 4])
+ for ax1_ax2_ax3_fused in range(1): # pylint:
disable=unused-variable
for ax0_fused in T.thread_binding(256,
thread="threadIdx.x"):
with T.block("matmul"):
vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
@@ -128,7 +126,7 @@ def test_decode_gemv_2():
vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1)
v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 *
16 + i2_i0_i1_fused_1)
vk_0_fused_0, vk_1 = T.axis.remap("RR",
[k_0_fused_0, k_1])
- C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 128 +
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) //
8, v_i2], T.Cast("uint32", (vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) % 8)
* T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 128 +
vk_0_fused_1 * 8 + vk_1) // 32, v_i2])
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 16 +
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[vk_0_fused_0 * 16 + vk_0_fused_1, v_i2],
T.Cast("uint32", ((vk_0_fused_0 * 16 + vk_0_fused_1) * 8 + vk_1) % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 16 +
vk_0_fused_1) // 4, v_i2])
for ax1_ax2_ax3_fused in T.thread_binding(16,
thread="threadIdx.x"):
for ax0_fused in T.thread_binding(16,
thread="threadIdx.y"):
with T.block("matmul"):
@@ -184,23 +182,26 @@ def test_decode_gemv_3():
for i2_1_init in range(8):
with T.block("matmul_rf_init"):
vk_fused_1 = T.axis.spatial(256, k_fused_1)
- v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 +
i2_1_init)
- C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0)
+ v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
+ v_i2 = T.axis.spatial(8, i2_1_init)
+ C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] =
T.float16(0)
for k_fused_0, i2_1 in T.grid(16, 8):
with T.block("matmul_rf_update"):
vk_fused_1 = T.axis.spatial(256, k_fused_1)
- v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 +
i2_1)
+ v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
vk_fused_0 = T.axis.reduce(16, k_fused_0)
- C_rf_local[vk_fused_1, 0, 0, v_i2] =
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 256 + vk_fused_1] *
((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2 // 8, vk_fused_0 * 256 +
vk_fused_1], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * S[v_i2 // 32, vk_fused_0 * 256 + vk_fused_1])
+ v_i2 = T.axis.spatial(8, i2_1)
+ C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] =
C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] + V[0, 0, vk_fused_0 * 256 +
vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i1,
vk_fused_0 * 256 + vk_fused_1], T.Cast("uint32", (v_i1 * 8 + v_i2) % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i1 // 4, vk_fused_0 * 256 +
vk_fused_1])
for ax1_ax2_ax3_fused_0 in range(1):
for ax0_fused in T.thread_binding(256,
thread="threadIdx.x"):
for ax1_ax2_ax3_fused_1 in range(8):
with T.block("matmul"):
vk_fused_1 = T.axis.reduce(256, ax0_fused)
- v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused *
8 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
+ v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
+ v_i2 = T.axis.spatial(8, ax1_ax2_ax3_fused_0 *
8 + ax1_ax2_ax3_fused_1)
with T.init():
- C[0, 0, v_i2] = T.float16(0)
- C[0, 0, v_i2] = C[0, 0, v_i2] +
C_rf_local[vk_fused_1, 0, 0, v_i2]
+ C[0, 0, v_i1 * 8 + v_i2] = T.float16(0)
+ C[0, 0, v_i1 * 8 + v_i2] = C[0, 0, v_i1 * 8 +
v_i2] + C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2]
# fmt: on
@@ -241,7 +242,6 @@ def test_decode_gemv_4():
@T.prim_func
def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
- # with T.block("root"):
C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16",
scope="local")
for i2_0_i0_i1_fused_0 in T.thread_binding(32,
thread="blockIdx.x"):
for i2_0_i0_i1_fused_1 in T.thread_binding(16,
thread="threadIdx.x"):
@@ -249,23 +249,26 @@ def test_decode_gemv_4():
for i2_1_init in range(8):
with T.block("matmul_rf_init"):
vk_fused_1 = T.axis.spatial(16, k_fused_1)
- v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init)
- C_rf_local[vk_fused_1, 0, 0, v_i2] =
T.float16(0)
+ v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 *
16 + i2_0_i0_i1_fused_1)
+ v2 = T.axis.spatial(8, i2_1_init)
+ C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] =
T.float16(0)
for k_fused_0, i2_1 in T.grid(256, 8):
with T.block("matmul_rf_update"):
vk_fused_1 = T.axis.spatial(16, k_fused_1)
- v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1)
+ v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 *
16 + i2_0_i0_i1_fused_1)
vk_fused_0 = T.axis.reduce(256, k_fused_0)
- C_rf_local[vk_fused_1, 0, 0, v_i2] =
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] *
((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1,
v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32])
+ v2 = T.axis.spatial(8, i2_1)
+ C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] =
C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] + V[0, 0, vk_fused_0 * 16 +
vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16
+ vk_fused_1, v1], T.Cast("uint32", (v1 * 8 + v2) % 8) * T.uint32(4)),
T.uint32(15))) - T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v1 // 4])
for ax1_ax2_ax3_fused_0 in T.thread_binding(16,
thread="threadIdx.x"):
for ax0_fused in T.thread_binding(16,
thread="threadIdx.y"):
for ax1_ax2_ax3_fused_1 in range(8):
with T.block("matmul"):
vk_fused_1 = T.axis.reduce(16, ax0_fused)
- v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0
* 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
+ v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 *
16 + (ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) // 8)
+ v2 = T.axis.spatial(8, (ax1_ax2_ax3_fused_0 *
8 + ax1_ax2_ax3_fused_1) % 8)
with T.init():
- C[0, 0, v_i2] = T.float16(0)
- C[0, 0, v_i2] = C[0, 0, v_i2] +
C_rf_local[vk_fused_1, 0, 0, v_i2]
+ C[0, 0, v1 * 8 + v2] = T.float16(0)
+ C[0, 0, v1 * 8 + v2] = C[0, 0, v1 * 8 + v2] +
C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2]
# fmt: on
@@ -325,8 +328,8 @@ def test_decode_gemv_sigmoid():
with T.block("matmul_rf_update"):
vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR",
[i2_i0_i1_fused, k_0_fused_0, k_1])
- C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 +
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 +
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) %
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 *
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
- for ax1_ax2_ax3_fused in range(1):
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 256 +
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i2, vk_0_fused_0 * 256 + vk_0_fused_1],
T.Cast("uint32", ((vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1) % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 256 +
vk_0_fused_1) // 4])
+ for ax1_ax2_ax3_fused in range(1): # pylint:
disable=unused-variable
for ax0_fused in T.thread_binding(256,
thread="threadIdx.x"):
with T.block("matmul"):
vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
@@ -397,9 +400,7 @@ def test_decode_gemv_1_fp32():
for ax1_0_fused_0, ax1_1 in T.grid(2, 8):
with T.block("matmul_rf_update"):
vax1_0_fused_1, v0, vax1_0_fused_0, vax1_1 =
T.axis.remap("SSRR", [ax1_0_fused_1, ax0_fused, ax1_0_fused_0, ax1_1])
- T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0],
V[0, 0, vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1], W[v0,
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0,
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 32])
- T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
- C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] =
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0,
vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1]) * T.Cast("float32",
(T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 2048 +
vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 2048 +
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * S[v0, (vax1_0_fused_0 * 2048 + vax1_0 [...]
+ C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] =
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0,
(vax1_0_fused_0 * 256 + vax1_0_fused_1) * 8 + vax1_1]) * T.Cast("float32",
(T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, vax1_0_fused_0 * 256 +
vax1_0_fused_1], T.Cast("uint32", ((vax1_0_fused_0 * 256 + vax1_0_fused_1) * 8
+ vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0,
(vax1_0_fused_0 * 256 + vax1_0_fused_1) // 4])
for ax1_fused in range(1):
for ax0_fused_1 in T.thread_binding(256,
thread="threadIdx.x"):
with T.block("matmul"):