This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 741ca41814 [Unity][Dlight] general reduction rule for gemv-decode
(#15169)
741ca41814 is described below
commit 741ca418143c3a53c607dfe88b4e49696aecaa85
Author: Bohan Hou <[email protected]>
AuthorDate: Sat Jul 1 22:18:27 2023 -0700
[Unity][Dlight] general reduction rule for gemv-decode (#15169)
This PR introduces a general rule for gemv-decode
---------
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Tianqi Chen <[email protected]>
---
python/tvm/dlight/__init__.py | 10 +-
python/tvm/dlight/base/__init__.py | 4 +-
python/tvm/dlight/base/analysis.py | 135 +++++--
python/tvm/dlight/base/common_schedules.py | 37 +-
python/tvm/dlight/gpu/__init__.py | 1 +
python/tvm/dlight/gpu/decode_gemv.py | 192 +++++++++
python/tvm/dlight/gpu/fallback.py | 16 +-
python/tvm/dlight/gpu/reduction.py | 38 +-
src/tir/schedule/analysis/analysis.cc | 15 +
src/tir/schedule/primitive/compute_inline.cc | 21 +-
src/tir/schedule/transform.cc | 52 +++
tests/python/dlight/test_gpu_decode_gemv.py | 432 +++++++++++++++++++++
tests/python/dlight/test_gpu_fallback.py | 15 +-
tests/python/dlight/test_gpu_reduction.py | 139 ++++---
.../unittest/test_tir_schedule_compute_inline.py | 24 ++
15 files changed, 970 insertions(+), 161 deletions(-)
diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py
index 23dd17993c..421d4017d1 100644
--- a/python/tvm/dlight/__init__.py
+++ b/python/tvm/dlight/__init__.py
@@ -16,4 +16,12 @@
# under the License.
"""DLight package provides efficient schedules out-of-box for deep learning
workloads."""
from . import gpu
-from .base import ApplyDefaultSchedule, ScheduleRule
+from .base import (
+ ApplyDefaultSchedule,
+ BlockInfo,
+ IterInfo,
+ ScheduleRule,
+ normalize_prim_func,
+ try_inline,
+ try_inline_contiguous_spatial,
+)
diff --git a/python/tvm/dlight/base/__init__.py
b/python/tvm/dlight/base/__init__.py
index d14db6c4a7..b69c82fca0 100644
--- a/python/tvm/dlight/base/__init__.py
+++ b/python/tvm/dlight/base/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Base infra"""
-from .analysis import BlockInfo
-from .common_schedules import try_inline
+from .analysis import BlockInfo, IterInfo, normalize_prim_func
+from .common_schedules import try_inline, try_inline_contiguous_spatial
from .schedule_rule import ScheduleRule
from .transform import ApplyDefaultSchedule
diff --git a/python/tvm/dlight/base/analysis.py
b/python/tvm/dlight/base/analysis.py
index a2508c87ba..a5d70c6c0c 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -15,61 +15,134 @@
# specific language governing permissions and limitations
# under the License.
"""Analysis on TIR blocks, loops and functions."""
-from typing import List, Union
+from typing import List, Optional, Union
+
+from typing_extensions import Literal
from tvm import tir
+from tvm._ffi import get_global_func
+
+
+class IterInfo:
+ """Information about a loop/iter var."""
+
+ kind: Literal["S", "R", "O"]
+ var: tir.Var
+ _dom: tir.PrimExpr
+ loop_rv: tir.schedule.LoopRV
+
+ def __init__(
+ self,
+ kind: Literal["S", "R", "O"],
+ var: tir.Var,
+ dom: tir.PrimExpr,
+ loop_rv: tir.schedule.LoopRV,
+ ):
+ """Construct an IterInfo object."""
+ self.kind = kind
+ self.var = var
+ self._dom = dom
+ self.loop_rv = loop_rv
+
+ @property
+ def dom(self) -> Union[int, tir.PrimExpr]:
+ """The iteration domain of the loop."""
+ return int(self._dom) if isinstance(self._dom, tir.IntImm) else
self._dom
+
+ def __str__(self) -> str:
+ return f'Iter("{self.kind}", {self.dom})'
+
+ def __repr__(self) -> str:
+ return str(self)
class BlockInfo:
"""Information about a TIR block."""
- block: tir.schedule.BlockRV
- """The TIR block the current schedule refers to"""
name: str
- """The name of the block"""
- iters: List[tir.IterVar]
- """The iteration domains of the current block"""
+ iters: List[IterInfo]
+ block_rv: tir.schedule.BlockRV
+ _reduction_block: bool
def __init__(
self,
- sch: tir.Schedule,
- block: tir.schedule.BlockRV,
+ name: str,
+ iters: List[IterInfo],
+ block_rv: tir.schedule.BlockRV,
+ reduction_block: bool = False,
):
- """Construct a BlockInfo object via TIR schedule."""
- tir_block = sch.get(block)
- self.block = block
- self.name = tir_block.name_hint
- self.iters = list(tir_block.iter_vars)
+ """Construct a BlockInfo object."""
+ self.name = name
+ self.block_rv = block_rv
+ self.iters = iters
+ self._reduction_block = reduction_block
def dom(self) -> List[Union[int, tir.PrimExpr]]:
"""The iteration domain of the block."""
-
- def _iter_dom(i: tir.IterVar) -> Union[int, tir.PrimExpr]:
- result = i.dom.extent
- if isinstance(result, tir.IntImm):
- result = int(result)
- return result
-
- result = [_iter_dom(i) for i in self.iters]
- return result
+ return [i.dom for i in self.iters]
def dom_kind(self) -> str:
"""The iteration domain kind of the block, for example, SSSS, SSSR."""
+ return "".join(i.kind for i in self.iters)
- def _iter_kind(i: tir.IterVar) -> str:
- return {
- tir.IterVar.DataPar: "S",
- tir.IterVar.CommReduce: "R",
- }.get(i.iter_type, "O")
+ def is_injective(self) -> bool:
+ """Whether the block is injective, i.e. all its iteration domains are
injective."""
+ return all(k == "S" for k in self.dom_kind())
- return "".join(_iter_kind(i) for i in self.iters)
+ def is_reduction(self) -> bool:
+ """Whether the block is a reduction workload."""
+ # TODO(@junrushao): distinguish GEMV and reduction
+ return self._reduction_block
- def is_spatial(self) -> bool:
- """Whether the block is spatial, i.e. all its iteration domains are
spatial."""
- return all(k == "S" for k in self.dom_kind())
+ def is_gemv(self) -> bool:
+ """Whether the block is a GEMV workload."""
+ raise NotImplementedError
+
+ def is_gemm(self) -> bool:
+ """Whether the block is a GEMM workload."""
+ raise NotImplementedError
def __str__(self) -> str:
return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})'
def __repr__(self) -> str:
return str(self)
+
+
+_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")
+
+
+def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
+ """Normalize the primfunc to normal form"""
+ try:
+ result = _normalize_prim_func(sch)
+ if result is None:
+ return None
+ except Exception: # pylint: disable=broad-except
+ return None
+
+ def _iter_kind(i: tir.IterVar) -> str:
+ return {
+ tir.IterVar.DataPar: "S",
+ tir.IterVar.CommReduce: "R",
+ }.get(i.iter_type, "O")
+
+ blocks: List[BlockInfo] = []
+ for block, loops, iters, is_reduction in zip(*result):
+ blocks.append(
+ BlockInfo(
+ name=sch.get(block).name_hint,
+ iters=[
+ IterInfo(
+ kind=_iter_kind(iter), # type: ignore
+ var=iter.var,
+ dom=iter.dom.extent,
+ loop_rv=loop,
+ )
+ for loop, iter in zip(loops, iters)
+ ],
+ block_rv=block,
+ reduction_block=is_reduction,
+ )
+ )
+ return blocks
diff --git a/python/tvm/dlight/base/common_schedules.py
b/python/tvm/dlight/base/common_schedules.py
index 6568f9e5b5..b91f46c3ca 100644
--- a/python/tvm/dlight/base/common_schedules.py
+++ b/python/tvm/dlight/base/common_schedules.py
@@ -44,7 +44,7 @@ def try_inline(
def _trial(func: Callable):
for i, block in enumerate(blocks):
try:
- func(block.block)
+ func(block.block_rv)
except: # pylint: disable=bare-except
continue
return i
@@ -58,3 +58,38 @@ def try_inline(
break
blocks.pop(i)
return blocks
+
+
+def try_inline_contiguous_spatial(sch: tir.Schedule, block_infos:
List[BlockInfo]):
+ """Try to inline contiguous spatial blocks in a schedule
+
+ Parameters
+ ----------
+ sch : tir.Schedule
+ The TIR schedule used to inline blocks.
+ block_infos : List[BlockInfo]
+ The blocks to try to inline.
+
+ Returns
+ -------
+ remaining : List[BlockInfo]
+ The remaining blocks that cannot be inlined.
+ """
+
+ if block_infos is None:
+ return None
+ results = []
+ spatial_blocks = []
+ block: BlockInfo
+ for block in block_infos:
+ if block.is_injective():
+ spatial_blocks.append(block)
+ elif spatial_blocks:
+ results.extend(try_inline(sch, spatial_blocks))
+ results.append(block)
+ spatial_blocks = []
+ else:
+ results.append(block)
+ if spatial_blocks:
+ results.extend(try_inline(sch, spatial_blocks))
+ return results
diff --git a/python/tvm/dlight/gpu/__init__.py
b/python/tvm/dlight/gpu/__init__.py
index 098f71d608..b689bef381 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -19,4 +19,5 @@ GPU-generic schedule rules.
For CUDA/ROCm/Vulkan/Metal-specific rules, use
`tvm.dlight.cuda/rocm/vulkan/metal` instead
"""
from .fallback import Fallback
+from .decode_gemv import DecodeGEMV
from .reduction import Reduction
diff --git a/python/tvm/dlight/gpu/decode_gemv.py
b/python/tvm/dlight/gpu/decode_gemv.py
new file mode 100644
index 0000000000..18395b8063
--- /dev/null
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""A reduction schedule rule for decode-GEMV workloads on GPU."""
+# pylint: disable=invalid-name
+
+from typing import List, Optional, Union
+
+from tvm import tir
+from tvm._ffi import get_global_func
+from tvm.arith import normalize_to_iter_sum
+from tvm.ir import structural_equal
+from tvm.target import Target
+
+from ..base import ScheduleRule, normalize_prim_func,
try_inline_contiguous_spatial
+
+
+def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
+ # Detect and return `Y` in `X[...] = X[...] + Y`
+ buffer_store = block.body
+ if not isinstance(buffer_store, tir.BufferStore):
+ return None
+ if not isinstance(buffer_store.value, tir.Add):
+ return None
+ if not structural_equal(
+ buffer_store.value.a,
+ tir.BufferLoad(buffer_store.buffer, block.body.indices),
+ map_free_vars=True,
+ ):
+ return None
+ return buffer_store.value.b
+
+
+def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
+ dominant_read, read_iters = None, None
+ tir_vars = set()
+ for buffer_region in block.reads:
+ tir_vars.clear()
+
+ def _collect_tir_var(e):
+ if isinstance(e, tir.Var):
+ tir_vars.add(e)
+
+ for expr in buffer_region.region:
+ assert expr.extent == 1
+ tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
+
+ if read_iters is None or read_iters < len(tir_vars):
+ read_iters = len(tir_vars)
+ dominant_read = buffer_region
+ assert dominant_read is not None
+ (result,) = dominant_read.buffer.offset_of([e.min for e in
dominant_read.region])
+ return result
+
+
+class DecodeGEMV(ScheduleRule):
+ def __init__(self) -> None:
+ super().__init__()
+ self.get_loop_iter_type =
get_global_func("tir.schedule.GetLoopIterType")
+
+ def apply( # pylint: disable=too-many-locals
+ self,
+ func: tir.PrimFunc,
+ target: Target,
+ _: bool,
+ ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+ if not isinstance(func, tir.PrimFunc):
+ return None
+
+ if target.kind.name == "cuda":
+ len_tx, len_ty = 16, 16
+ else:
+ len_tx, len_ty = 8, 8
+
+ sch = tir.Schedule(func)
+ block_infos = try_inline_contiguous_spatial(sch,
normalize_prim_func(sch))
+
+ if block_infos is None or len(block_infos) > 2:
+ return None
+
+ block_info = block_infos[0]
+ block = block_info.block_rv
+ block_stmt = sch.get(block)
+
+ # Step 1. Check reduction block
+ if not block_info.is_reduction():
+ return None
+ if len(block_stmt.writes) != 1:
+ return None
+ if _get_reduction_expr(block_stmt) is None:
+ return None
+
+ # Step 2. Sort out the spatial and reduction loops
+ sorted_iter_access = normalize_to_iter_sum(
+ _detect_dominant_read(block_stmt),
+ input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+ )
+ if sorted_iter_access.base != 0:
+ return None
+ iter_to_info = {i.var: i for i in block_info.iters}
+ s_loops, r_loops, c_loops = [], [], []
+ for split in sorted_iter_access.args:
+ block_var = split.source.source
+ block_var_info = iter_to_info[block_var]
+ loop_rv = block_var_info.loop_rv
+ is_inner_reduction = block_var_info.kind == "R"
+ if split.lower_factor > 1:
+ c_loop_factor = split.lower_factor
+ loop_rv, c_loop = sch.split(loop_rv, factors=[None,
c_loop_factor])
+ c_loops.append(c_loop)
+ is_loop_c_reduction = is_inner_reduction
+ if is_inner_reduction:
+ r_loops.append(loop_rv)
+ else:
+ s_loops.append(loop_rv)
+
+ if len(c_loops) > 1:
+ return None
+ if len(s_loops) != len([_ for i in block_info.iters if i.kind == "S"]):
+ return None
+ if len(s_loops) == 0 or len(r_loops) == 0:
+ return None
+
+ sch.reorder(*s_loops, *r_loops, *c_loops)
+ s = sch.fuse(*s_loops)
+ r = sch.fuse(*r_loops)
+
+ if is_inner_reduction:
+ _, tx = sch.split(r, factors=[None, len_tx * len_ty])
+ rf = sch.rfactor(tx, 0)
+ s, r, tx = sch.get_loops(rf)[:3]
+ sch.reorder(s, tx, r)
+ sch.reverse_compute_at(block, s, preserve_unit_loops=True)
+ sch.bind(tx, "threadIdx.x")
+ sch.bind(s, "blockIdx.x")
+ else:
+ sch.split(s, factors=[None, len_tx])
+ _, ty = sch.split(r, factors=[None, len_ty])
+ rf = sch.rfactor(ty, 0)
+ bx, tx, r, ty = sch.get_loops(rf)[:4]
+ sch.reorder(bx, tx, ty, r)
+ sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+ sch.bind(tx, "threadIdx.x")
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(bx, "blockIdx.x")
+
+ s_loops, r_loops = [], []
+ for loop_rv in sch.get_loops(block)[1:]:
+ iter_type = self.get_loop_iter_type(sch, loop_rv)
+ if iter_type == "S":
+ s_loops.append(loop_rv)
+ elif iter_type == "R":
+ r_loops.append(loop_rv)
+ else:
+ raise RuntimeError("Unknown loop type " + str(iter_type))
+ sch.reorder(*s_loops, *r_loops)
+ s_ctr = sch.fuse(*s_loops)
+ r_ctr = sch.fuse(*r_loops)
+
+ if c_loops and not is_loop_c_reduction:
+ s_ctr, inner = sch.split(s_ctr, factors=[None, c_loop_factor])
+ sch.reorder(s_ctr, r_ctr, inner)
+
+ if is_inner_reduction:
+ sch.bind(r_ctr, "threadIdx.x")
+ sch.set_scope(rf, 0, "local")
+ sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+ else:
+ sch.bind(s_ctr, "threadIdx.x")
+ sch.bind(r_ctr, "threadIdx.y")
+ sch.set_scope(rf, 0, "local")
+ sch.decompose_reduction(rf, sch.get_loops(rf)[3])
+
+ if len(block_infos) == 2:
+ sch.set_scope(block, 0, "local")
+ sch.reverse_compute_at(block_infos[1].block_rv,
sch.get_loops(block)[0])
+
+ return sch
diff --git a/python/tvm/dlight/gpu/fallback.py
b/python/tvm/dlight/gpu/fallback.py
index caefc8d563..63033aa7c7 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -21,7 +21,7 @@ from typing import List
from tvm import tir
from tvm.target import Target
-from ..base import BlockInfo, ScheduleRule, try_inline
+from ..base import ScheduleRule, normalize_prim_func, try_inline
def _max_threads_per_block(target: Target) -> int:
@@ -49,21 +49,13 @@ class Fallback(ScheduleRule):
max_threads_per_block = _max_threads_per_block(target)
sch = tir.Schedule(func)
- for block in try_inline(
- sch,
- [
- BlockInfo(
- sch,
- block,
- )
- for block in sch.get_child_blocks(sch.get_block("root"))
- ],
- ):
+ block_infos = try_inline(sch, normalize_prim_func(sch))
+ for block in block_infos:
s_loops: List[tir.schedule.LoopRV] = []
r_loops: List[tir.schedule.LoopRV] = []
o_loops: List[tir.schedule.LoopRV] = []
dom_kind = block.dom_kind()
- block = block.block
+ block = block.block_rv
for loop, iter_type in zip(sch.get_loops(block), dom_kind):
{"S": s_loops, "R": r_loops, "O":
o_loops}[iter_type].append(loop)
diff --git a/python/tvm/dlight/gpu/reduction.py
b/python/tvm/dlight/gpu/reduction.py
index b3cc58c902..bfca76546f 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -20,7 +20,7 @@ from typing import List, Union
from tvm import tir
from tvm.target import Target
-from ..base import BlockInfo, ScheduleRule, try_inline
+from ..base import ScheduleRule, normalize_prim_func,
try_inline_contiguous_spatial
class Reduction(ScheduleRule):
@@ -39,47 +39,31 @@ class Reduction(ScheduleRule):
len_tx = 64
unroll_depth = 64
- def _inline_all_spatial():
- blocks = []
- spatial_blocks = []
- for block in sch.get_child_blocks(sch.get_block("root")):
- block = BlockInfo(sch, block)
- if block.is_spatial():
- spatial_blocks.append(block)
- elif spatial_blocks:
- blocks.extend(try_inline(sch, spatial_blocks))
- blocks.append(block)
- spatial_blocks = []
- else:
- blocks.append(block)
- if spatial_blocks:
- blocks.extend(try_inline(sch, spatial_blocks))
- return blocks
-
sch = tir.Schedule(func)
- blocks = _inline_all_spatial()
- assert len(blocks) > 0
+ block_infos = normalize_prim_func(sch)
+ block_infos = try_inline_contiguous_spatial(sch, block_infos)
+ assert len(block_infos) > 0
- dom_kind = blocks[0].dom_kind()
+ dom_kind = block_infos[0].dom_kind()
num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S"))
num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R"))
try:
- for block in blocks[1:-1]:
+ for block in block_infos[1:-1]:
assert block.dom_kind() == dom_kind
- assert blocks[-1].is_spatial()
- assert len(blocks[-1].dom_kind()) == len(dom_kind)
+ assert block_infos[-1].is_injective()
+ assert len(block_infos[-1].dom_kind()) == len(dom_kind)
except AssertionError:
print("Mismatch")
return None
- loops = sch.get_loops(blocks[-1].block)
+ loops = sch.get_loops(block_infos[-1].block_rv)
bx = sch.fuse(*loops[:num_leading_s]) # pylint: disable=invalid-name
_, tx = sch.split(loops[-1], [None, len_tx]) # pylint:
disable=invalid-name
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
- for block in reversed(blocks[:-1]):
- block = block.block
+ for block in reversed(block_infos[:-1]):
+ block = block.block_rv
for i, _ in enumerate(sch.get(block).writes):
sch.set_scope(block, buffer_index=i, storage_scope="shared")
sch.compute_at(block, bx, preserve_unit_loops=True)
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index 5588458235..bab0d4a3a9 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -328,6 +328,11 @@ bool IsReductionBlock(const ScheduleState& self, const
StmtSRef& block_sref,
return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0;
}
+TVM_REGISTER_GLOBAL("tir.schedule.IsReductionBlock")
+ .set_body_typed([](Schedule sch, BlockRV block_rv, BlockRV scope_block_rv)
{
+ return IsReductionBlock(sch->state(), sch->GetSRef(block_rv),
sch->GetSRef(scope_block_rv));
+ });
+
void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
class NotReductionBlockError : public ScheduleError {
@@ -859,6 +864,11 @@ BlockRealize GetBlockRealize(const ScheduleState& self,
const StmtSRef& block_sr
}
}
+TVM_REGISTER_GLOBAL("tir.schedule.GetBlockRealize")
+ .set_body_typed([](Schedule sch, BlockRV block_rv) {
+ return GetBlockRealize(sch->state(), sch->GetSRef(block_rv));
+ });
+
IterVarType GetLoopIterType(const StmtSRef& loop_sref) {
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
const Var& loop_var = loop->loop_var;
@@ -1459,6 +1469,11 @@ bool IsTrivialBinding(const ScheduleState& self, const
StmtSRef& block_sref) {
return true;
}
+TVM_REGISTER_GLOBAL("tir.schedule.IsTrivialBinding")
+ .set_body_typed([](Schedule sch, BlockRV block_rv) {
+ return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv));
+ });
+
bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef&
block_sref) {
if (HasBeenMultiLevelTiled(block_sref)) {
return false;
diff --git a/src/tir/schedule/primitive/compute_inline.cc
b/src/tir/schedule/primitive/compute_inline.cc
index 453f962aa5..d6be0e5805 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -177,14 +177,18 @@ class NonSingleProducerError : public ScheduleError {
self, GetRef<Block>(consumer_block), scope_root_sref);
class ProducerFinder : public StmtVisitor {
public:
- static std::vector<Block> GetProducer(const Buffer& buffer, const Block&
scope_block) {
- ProducerFinder finder(buffer);
+ static std::vector<Block> GetProducer(const ScheduleState& self,
+ const StmtSRef& scope_root_sref,
const Buffer& buffer,
+ const Block& scope_block) {
+ ProducerFinder finder(self, scope_root_sref, buffer);
finder(scope_block);
return finder.producer_across_scope_.back();
}
private:
- explicit ProducerFinder(const Buffer& buffer) : buffer_(buffer) {
+ explicit ProducerFinder(const ScheduleState& self, const StmtSRef&
scope_root_sref,
+ const Buffer& buffer)
+ : self_(self), scope_root_sref_(scope_root_sref), buffer_(buffer) {
producer_across_scope_.push_back({});
}
@@ -204,16 +208,23 @@ class NonSingleProducerError : public ScheduleError {
producer_across_scope_.pop_back();
for (const auto& write : node->writes) {
if (write->buffer.same_as(buffer_)) {
+ // Check if the producer block is a complete block
+ StmtSRef producer_block_sref = self_->stmt2ref.at(node);
+ if (!IsCompleteBlock(self_, producer_block_sref,
scope_root_sref_)) {
+ throw NonSingleProducerError(self_->mod, GetRef<Block>(node));
+ }
producer_across_scope_.back().push_back(GetRef<Block>(node));
break;
}
}
}
+ ScheduleState self_;
+ StmtSRef scope_root_sref_;
Buffer buffer_;
std::vector<std::vector<Block>> producer_across_scope_;
};
- std::vector<Block> producer_across_scope =
- ProducerFinder::GetProducer(consumer_buffer,
GetRef<Block>(scope_block));
+ std::vector<Block> producer_across_scope = ProducerFinder::GetProducer(
+ self, scope_root_sref, consumer_buffer, GetRef<Block>(scope_block));
if (producer_across_scope.size() != 1) {
throw NonSingleProducerError(self->mod, GetRef<Block>(consumer_block));
}
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index e4047273b6..35bf7b7669 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -429,5 +429,57 @@ PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const
BufferLoadNode* op) {
return std::move(node);
}
+/******** PrimFunc-level analysis and transformation ********/
+
+Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
+ BlockRV root_block = sch->GetBlock("root");
+ Array<BlockRV> blocks = sch->GetChildBlocks(root_block);
+ for (const BlockRV& block : blocks) {
+ StmtSRef block_sref = sch->GetSRef(block);
+ Array<StmtSRef> loops = GetLoops(block_sref);
+ Array<PrimExpr> binds = GetBlockRealize(sch->state(),
block_sref)->iter_values;
+ if (loops.size() != binds.size()) {
+ return NullOpt;
+ }
+ for (int i = 0, n = loops.size(); i < n; ++i) {
+ const ForNode* loop = TVM_SREF_TO_FOR(loops[i]);
+ if (binds[i].get() != loop->loop_var.get()) {
+ return NullOpt;
+ }
+ if (!is_zero(loop->min)) {
+ return NullOpt;
+ }
+ }
+ }
+ Array<Array<LoopRV>> block_loops;
+ Array<Array<IterVar>> block_iters;
+ Array<IntImm> block_is_reduction;
+ for (const BlockRV& block : blocks) {
+ Array<IterVar> iters = sch->Get(block)->iter_vars;
+ Array<Var> index_map_inputs;
+ Array<PrimExpr> index_map_outputs;
+ for (const IterVar& iter : sch->Get(block)->iter_vars) {
+ Var var = iter->var.copy_with_suffix("");
+ index_map_inputs.push_back(var);
+ if (!is_one(iter->dom->extent)) {
+ index_map_outputs.push_back(var);
+ }
+ }
+ if (index_map_outputs.empty()) {
+ index_map_outputs.push_back(make_zero(DataType::Int(64)));
+ }
+ sch->TransformBlockLayout(block, IndexMap(index_map_inputs,
index_map_outputs));
+ block_loops.push_back(sch->GetLoops(block));
+ block_iters.push_back(sch->Get(block)->iter_vars);
+ bool is_reduction = IsReductionBlock(sch->state(), //
+ sch->GetSRef(block), //
+ sch->GetSRef(root_block));
+ block_is_reduction.push_back(Bool(is_reduction));
+ }
+ return Array<ObjectRef>{blocks, block_loops, block_iters,
block_is_reduction};
+}
+
+TVM_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc);
+
} // namespace tir
} // namespace tvm
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py
b/tests/python/dlight/test_gpu_decode_gemv.py
new file mode 100644
index 0000000000..46232a461e
--- /dev/null
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -0,0 +1,432 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm
+from tvm import dlight as dl
+from tvm.ir import assert_structural_equal
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def test_decode_gemv_1():
+ # NK layout + K as decode dim
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ B = T.alloc_buffer((4096, 4096), "float16")
+ for i, j in T.grid(4096, 4096):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32])
+ T.writes(B[v_i, v_j])
+ B[v_i, v_j] = (T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32]
+ for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2,
k])
+ T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k])
+ T.writes(C[v_i0, v_i1, v_i2])
+ with T.init():
+ C[v_i0, v_i1, v_i2] = T.float16(0)
+ C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1,
v_k] * B[v_i2, v_k]
+
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16",
scope="local")
+ for i2_i0_i1_fused in T.thread_binding(4096, thread="blockIdx.x"):
+ for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ with T.block("matmul_rf_init"):
+ vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
+ v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] = T.float16(0)
+ for k_0_fused_0, k_1 in T.grid(2, 8):
+ with T.block("matmul_rf_update"):
+ vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
+ v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR",
[i2_i0_i1_fused, k_0_fused_0, k_1])
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 +
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 +
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) %
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 *
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
+ for ax1_ax2_ax3_fused in range(1):
+ for ax0_fused in T.thread_binding(256,
thread="threadIdx.x"):
+ with T.block("matmul"):
+ vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
+ v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+ with T.init():
+ C[0, 0, v_i2] = T.float16(0)
+ C[0, 0, v_i2] = C[0, 0, v_i2] +
C_rf_local[vk_0_fused_1, 0, 0, v_i2]
+ # fmt: on
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ with target:
+ mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before) # pylint:
disable=not-callable
+ assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_2():
+ # KN layout + K as decode dim
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ B = T.alloc_buffer((4096, 4096), "float16")
+ for i, j in T.grid(4096, 4096):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(W[v_i // 8, v_j], S[v_i // 32, v_j])
+ T.writes(B[v_i, v_j])
+ B[v_i, v_j] = (T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i // 8, v_j], T.Cast("uint32", v_i % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i // 32, v_j]
+ for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2,
k])
+ T.reads(V[v_i0, v_i1, v_k], B[v_k, v_i2])
+ T.writes(C[v_i0, v_i1, v_i2])
+ with T.init():
+ C[v_i0, v_i1, v_i2] = T.float16(0)
+ C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1,
v_k] * B[v_k, v_i2]
+
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16",
scope="local")
+ for i2_i0_i1_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
+ for i2_i0_i1_fused_1 in T.thread_binding(16,
thread="threadIdx.x"):
+ for k_0_fused_1 in T.thread_binding(16,
thread="threadIdx.y"):
+ with T.block("matmul_rf_init"):
+ vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1)
+ v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16
+ i2_i0_i1_fused_1)
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] = T.float16(0)
+ for k_0_fused_0, k_1 in T.grid(32, 8):
+ with T.block("matmul_rf_update"):
+ vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1)
+ v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 *
16 + i2_i0_i1_fused_1)
+ vk_0_fused_0, vk_1 = T.axis.remap("RR",
[k_0_fused_0, k_1])
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 128 +
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) //
8, v_i2], T.Cast("uint32", (vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) % 8)
* T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 128 +
vk_0_fused_1 * 8 + vk_1) // 32, v_i2])
+ for ax1_ax2_ax3_fused in T.thread_binding(16,
thread="threadIdx.x"):
+ for ax0_fused in T.thread_binding(16,
thread="threadIdx.y"):
+ with T.block("matmul"):
+ vk_0_fused_1 = T.axis.reduce(16, ax0_fused)
+ v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16
+ ax1_ax2_ax3_fused)
+ with T.init():
+ C[0, 0, v_i2] = T.float16(0)
+ C[0, 0, v_i2] = C[0, 0, v_i2] +
C_rf_local[vk_0_fused_1, 0, 0, v_i2]
+
+ # fmt: on
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ with target:
+ mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before) # pylint:
disable=not-callable
+ assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_3():
+ # NK layout + N as decode dim
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ B = T.alloc_buffer((4096, 4096), "float16")
+ for i, j in T.grid(4096, 4096):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(W[v_i // 8, v_j], S[v_i // 32, v_j])
+ T.writes(B[v_i, v_j])
+ B[v_i, v_j] = (T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i // 8, v_j], T.Cast("uint32", v_i % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i // 32, v_j]
+ for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2,
k])
+ T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k])
+ T.writes(C[v_i0, v_i1, v_i2])
+ with T.init():
+ C[v_i0, v_i1, v_i2] = T.float16(0)
+ C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1,
v_k] * B[v_i2, v_k]
+
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16",
scope="local")
+ for i2_0_i0_i1_fused in T.thread_binding(512, thread="blockIdx.x"):
+ for k_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ for i2_1_init in range(8):
+ with T.block("matmul_rf_init"):
+ vk_fused_1 = T.axis.spatial(256, k_fused_1)
+ v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 +
i2_1_init)
+ C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0)
+ for k_fused_0, i2_1 in T.grid(16, 8):
+ with T.block("matmul_rf_update"):
+ vk_fused_1 = T.axis.spatial(256, k_fused_1)
+ v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 +
i2_1)
+ vk_fused_0 = T.axis.reduce(16, k_fused_0)
+ C_rf_local[vk_fused_1, 0, 0, v_i2] =
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 256 + vk_fused_1] *
((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2 // 8, vk_fused_0 * 256 +
vk_fused_1], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * S[v_i2 // 32, vk_fused_0 * 256 + vk_fused_1])
+ for ax1_ax2_ax3_fused_0 in range(1):
+ for ax0_fused in T.thread_binding(256,
thread="threadIdx.x"):
+ for ax1_ax2_ax3_fused_1 in range(8):
+ with T.block("matmul"):
+ vk_fused_1 = T.axis.reduce(256, ax0_fused)
+ v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused *
8 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
+ with T.init():
+ C[0, 0, v_i2] = T.float16(0)
+ C[0, 0, v_i2] = C[0, 0, v_i2] +
C_rf_local[vk_fused_1, 0, 0, v_i2]
+
+ # fmt: on
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ with target:
+ mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before) # pylint:
disable=not-callable
+ assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_4():
+ # KN layout + N as decode dim
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ B = T.alloc_buffer((4096, 4096), "float16")
+ for i, j in T.grid(4096, 4096):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32])
+ T.writes(B[v_i, v_j])
+ B[v_i, v_j] = (T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32]
+ for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2,
k])
+ T.reads(V[v_i0, v_i1, v_k], B[v_k, v_i2])
+ T.writes(C[v_i0, v_i1, v_i2])
+ with T.init():
+ C[v_i0, v_i1, v_i2] = T.float16(0)
+ C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1,
v_k] * B[v_k, v_i2]
+
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16",
scope="local")
+ for i2_0_i0_i1_fused_0 in T.thread_binding(32,
thread="blockIdx.x"):
+ for i2_0_i0_i1_fused_1 in T.thread_binding(16,
thread="threadIdx.x"):
+ for k_fused_1 in T.thread_binding(16,
thread="threadIdx.y"):
+ for i2_1_init in range(8):
+ with T.block("matmul_rf_init"):
+ vk_fused_1 = T.axis.spatial(16, k_fused_1)
+ v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init)
+ C_rf_local[vk_fused_1, 0, 0, v_i2] =
T.float16(0)
+ for k_fused_0, i2_1 in T.grid(256, 8):
+ with T.block("matmul_rf_update"):
+ vk_fused_1 = T.axis.spatial(16, k_fused_1)
+ v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1)
+ vk_fused_0 = T.axis.reduce(256, k_fused_0)
+ C_rf_local[vk_fused_1, 0, 0, v_i2] =
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] *
((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1,
v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32])
+ for ax1_ax2_ax3_fused_0 in T.thread_binding(16,
thread="threadIdx.x"):
+ for ax0_fused in T.thread_binding(16,
thread="threadIdx.y"):
+ for ax1_ax2_ax3_fused_1 in range(8):
+ with T.block("matmul"):
+ vk_fused_1 = T.axis.reduce(16, ax0_fused)
+ v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0
* 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
+ with T.init():
+ C[0, 0, v_i2] = T.float16(0)
+ C[0, 0, v_i2] = C[0, 0, v_i2] +
C_rf_local[vk_fused_1, 0, 0, v_i2]
+
+ # fmt: on
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ with target:
+ mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before) # pylint:
disable=not-callable
+ assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_sigmoid():
+ # NK layout + K as decode dim
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), D: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ B = T.alloc_buffer((4096, 4096), "float16")
+ C = T.alloc_buffer((1, 1, 4096), "float16")
+ for i, j in T.grid(4096, 4096):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32])
+ T.writes(B[v_i, v_j])
+ B[v_i, v_j] = (T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32]
+ for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2,
k])
+ T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k])
+ T.writes(C[v_i0, v_i1, v_i2])
+ with T.init():
+ C[v_i0, v_i1, v_i2] = T.float16(0)
+ C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1,
v_k] * B[v_i2, v_k]
+ for i0, i1, i2 in T.grid(1, 1, 4096):
+ with T.block("sigmoid"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(C[v_i0, v_i1, v_i2])
+ T.writes(D[v_i0, v_i1, v_i2])
+ D[v_i0, v_i1, v_i2] = T.sigmoid(C[v_i0, v_i1, v_i2])
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), D: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ C_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
+ C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16",
scope="local")
+ for i2_i0_i1_fused in T.thread_binding(4096, thread="blockIdx.x"):
+ for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ with T.block("matmul_rf_init"):
+ vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
+ v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] = T.float16(0)
+ for k_0_fused_0, k_1 in T.grid(2, 8):
+ with T.block("matmul_rf_update"):
+ vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
+ v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR",
[i2_i0_i1_fused, k_0_fused_0, k_1])
+ C_rf_local[vk_0_fused_1, 0, 0, v_i2] =
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 +
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 +
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) %
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 *
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
+ for ax1_ax2_ax3_fused in range(1):
+ for ax0_fused in T.thread_binding(256,
thread="threadIdx.x"):
+ with T.block("matmul"):
+ vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
+ v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+ with T.init():
+ C_local[0, 0, v_i2] = T.float16(0)
+ C_local[0, 0, v_i2] = C_local[0, 0, v_i2] +
C_rf_local[vk_0_fused_1, 0, 0, v_i2]
+ with T.block("sigmoid"):
+ v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
+ D[0, 0, v_i2] = T.sigmoid(C_local[0, 0, v_i2])
+
+ # fmt: on
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ with target:
+ mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before) # pylint:
disable=not-callable
+ assert_structural_equal(mod, After)
+
+
+def test_decode_gemv_1_fp32():
+ # NK layout + K as decode dim
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ B = T.alloc_buffer((4096, 4096), "float16")
+ C_fp32 = T.alloc_buffer((1, 1, 4096), "float32")
+ for i, j in T.grid(4096, 4096):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32])
+ T.writes(B[v_i, v_j])
+ B[v_i, v_j] = (T.Cast("float16",
T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) *
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32]
+ for i0, i1, i2, k in T.grid(1, 1, 4096, 4096):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2,
k])
+ T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k])
+ T.writes(C_fp32[v_i0, v_i1, v_i2])
+ with T.init():
+ C_fp32[v_i0, v_i1, v_i2] = T.float16(0)
+ C_fp32[v_i0, v_i1, v_i2] = C_fp32[v_i0, v_i1, v_i2] +
T.Cast("float32", V[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_i2, v_k])
+ for i0, i1, i2 in T.grid(1, 1, 4096):
+ with T.block("cast"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(C_fp32[v_i0, v_i1, v_i2])
+ T.writes(C[v_i0, v_i1, v_i2])
+ C[v_i0, v_i1, v_i2] = T.Cast("float16", C_fp32[v_i0, v_i1,
v_i2])
+
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128),
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096),
"float16")):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ C_fp32_local = T.alloc_buffer((1, 1, 4096), scope="local")
+ C_fp32_rf_local = T.alloc_buffer((256, 1, 1, 4096), scope="local")
+ for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"):
+ for ax1_0_fused_1 in T.thread_binding(256,
thread="threadIdx.x"):
+ with T.block("matmul_rf_init"):
+ vax1_0_fused_1, v0 = T.axis.remap("SS",
[ax1_0_fused_1, ax0_fused])
+ T.reads()
+ T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
+ C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] =
T.float32(0)
+ for ax1_0_fused_0, ax1_1 in T.grid(2, 8):
+ with T.block("matmul_rf_update"):
+ vax1_0_fused_1, v0, vax1_0_fused_0, vax1_1 =
T.axis.remap("SSRR", [ax1_0_fused_1, ax0_fused, ax1_0_fused_0, ax1_1])
+ T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0],
V[0, 0, vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1], W[v0,
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0,
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 32])
+ T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
+ C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] =
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0,
vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1]) * T.Cast("float32",
(T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 2048 +
vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 2048 +
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * S[v0, (vax1_0_fused_0 * 2048 + vax1_0 [...]
+ for ax1_fused in range(1):
+ for ax0_fused_1 in T.thread_binding(256,
thread="threadIdx.x"):
+ with T.block("matmul"):
+ vax1_0_fused_1, v0 = T.axis.remap("RS",
[ax0_fused_1, ax0_fused])
+ T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
+ T.writes(C_fp32_local[0, 0, v0])
+ with T.init():
+ C_fp32_local[0, 0, v0] = T.float32(0)
+ C_fp32_local[0, 0, v0] = C_fp32_local[0, 0, v0] +
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]
+ with T.block("cast"):
+ v0 = T.axis.spatial(4096, ax0_fused)
+ T.reads(C_fp32_local[0, 0, v0])
+ T.writes(C[0, 0, v0])
+ C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0, v0])
+
+ # fmt: on
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ with target:
+ mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before) # pylint:
disable=not-callable
+ assert_structural_equal(mod, After)
+
+
+if __name__ == "__main__":
+ test_decode_gemv_1()
+ test_decode_gemv_2()
+ test_decode_gemv_3()
+ test_decode_gemv_4()
+ test_decode_gemv_sigmoid()
+ test_decode_gemv_1_fp32()
diff --git a/tests/python/dlight/test_gpu_fallback.py
b/tests/python/dlight/test_gpu_fallback.py
index 8d20169c53..38e9a391dc 100644
--- a/tests/python/dlight/test_gpu_fallback.py
+++ b/tests/python/dlight/test_gpu_fallback.py
@@ -48,16 +48,13 @@ def test_fallback():
C: T.Buffer((1, 1, 4096), "float16"),
):
T.func_attr({"tir.is_scheduled": 1})
- # with T.block("root"):
- for i_j_k_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
- for i_j_k_fused_1 in T.thread_binding(1024,
thread="threadIdx.x"):
+ for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
+ for ax0_fused_1 in T.thread_binding(1024,
thread="threadIdx.x"):
with T.block("T_reshape"):
- vi = T.axis.spatial(1, 0)
- vj = T.axis.spatial(1, 0)
- vk = T.axis.spatial(4096, i_j_k_fused_0 * 1024 +
i_j_k_fused_1)
- T.reads(A[0, vk % 4096 // 128, 0, vk % 128])
- T.writes(C[vi, vj, vk])
- C[vi, vj, vk] = A[0, vk % 4096 // 128, 0, vk % 128]
+ v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 +
ax0_fused_1)
+ T.reads(A[0, v0 // 128, 0, v0 % 128])
+ T.writes(C[0, 0, v0])
+ C[0, 0, v0] = A[0, v0 // 128, 0, v0 % 128]
target = Target("nvidia/geforce-rtx-3090-ti")
with target:
diff --git a/tests/python/dlight/test_gpu_reduction.py
b/tests/python/dlight/test_gpu_reduction.py
index 99307093c8..ce3a05e771 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -93,44 +93,41 @@ def test_softmax():
# with T.block("root"):
T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1),
T.int64(32), n), scope="shared")
T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(32),
n), scope="shared")
- for i0_i1_i2_fused in T.thread_binding(n * T.int64(32),
thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
- for ax0, ax1, ax2, ax3_fused_0 in T.grid(T.int64(1),
T.int64(1), T.int64(1), (m + T.int64(255)) // T.int64(256)):
- for ax3_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ for ax0_ax1_fused in T.thread_binding(n * T.int64(32),
thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
+ for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1), (m
+ T.int64(255)) // T.int64(256)):
+ for ax2_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
with T.block("T_softmax_maxelem"):
- v_i0 = T.axis.spatial(T.int64(1), ax0)
- v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused
// n + ax1)
- v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n + ax2)
- v_k = T.axis.reduce(m, ax3_fused_0 * T.int64(256)
+ ax3_fused_1)
- T.where(T.int64(0) <= i0_i1_i2_fused // n and
i0_i1_i2_fused // n < T.int64(32) and T.int64(0) <= i0_i1_i2_fused % n and
i0_i1_i2_fused % n < n and ax3_fused_0 * T.int64(256) + ax3_fused_1 < m)
- T.reads(lv44[v_i0, v_i1, v_i2, v_k])
- T.writes(T_softmax_maxelem_shared[v_i0, v_i1,
v_i2])
+ v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused //
n + ax0)
+ v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
+ v2 = T.axis.reduce(m, ax2_fused_0 * T.int64(256) +
ax2_fused_1)
+ T.where(T.int64(0) <= ax0_ax1_fused // n and
ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and
ax0_ax1_fused % n < n and ax2_fused_0 * T.int64(256) + ax2_fused_1 < m)
+ T.reads(lv44[T.int64(0), v0, v1, v2])
+ T.writes(T_softmax_maxelem_shared[T.int64(0), v0,
v1])
with T.init():
- T_softmax_maxelem_shared[v_i0, v_i1, v_i2] =
T.float32(-3.4028234663852886e+38)
- T_softmax_maxelem_shared[v_i0, v_i1, v_i2] =
T.max(T_softmax_maxelem_shared[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k])
- for ax0, ax1, ax2, ax3_fused_0 in T.grid(T.int64(1),
T.int64(1), T.int64(1), (m + T.int64(255)) // T.int64(256)):
- for ax3_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ T_softmax_maxelem_shared[T.int64(0), v0, v1] =
T.float32(-3.4028234663852886e+38)
+ T_softmax_maxelem_shared[T.int64(0), v0, v1] =
T.max(T_softmax_maxelem_shared[T.int64(0), v0, v1], lv44[T.int64(0), v0, v1,
v2])
+ for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1), (m
+ T.int64(255)) // T.int64(256)):
+ for ax2_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
with T.block("T_softmax_expsum"):
- v_i0 = T.axis.spatial(T.int64(1), ax0)
- v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused
// n + ax1)
- v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n + ax2)
- v_k = T.axis.reduce(m, ax3_fused_0 * T.int64(256)
+ ax3_fused_1)
- T.where(T.int64(0) <= i0_i1_i2_fused // n and
i0_i1_i2_fused // n < T.int64(32) and T.int64(0) <= i0_i1_i2_fused % n and
i0_i1_i2_fused % n < n and ax3_fused_0 * T.int64(256) + ax3_fused_1 < m)
- T.reads(lv44[v_i0, v_i1, v_i2, v_k],
T_softmax_maxelem_shared[v_i0, v_i1, v_i2])
- T.writes(T_softmax_expsum_shared[v_i0, v_i1, v_i2])
+ v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused //
n + ax0)
+ v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
+ v2 = T.axis.reduce(m, ax2_fused_0 * T.int64(256) +
ax2_fused_1)
+ T.where(T.int64(0) <= ax0_ax1_fused // n and
ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and
ax0_ax1_fused % n < n and ax2_fused_0 * T.int64(256) + ax2_fused_1 < m)
+ T.reads(lv44[T.int64(0), v0, v1, v2],
T_softmax_maxelem_shared[T.int64(0), v0, v1])
+ T.writes(T_softmax_expsum_shared[T.int64(0), v0,
v1])
with T.init():
- T_softmax_expsum_shared[v_i0, v_i1, v_i2] =
T.float32(0)
- T_softmax_expsum_shared[v_i0, v_i1, v_i2] =
T_softmax_expsum_shared[v_i0, v_i1, v_i2] + T.exp(lv44[v_i0, v_i1, v_i2, v_k] -
T_softmax_maxelem_shared[v_i0, v_i1, v_i2])
- for i3_0 in range((m + T.int64(255)) // T.int64(256)):
- for i3_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ T_softmax_expsum_shared[T.int64(0), v0, v1] =
T.float32(0)
+ T_softmax_expsum_shared[T.int64(0), v0, v1] =
T_softmax_expsum_shared[T.int64(0), v0, v1] + T.exp(lv44[T.int64(0), v0, v1,
v2] - T_softmax_maxelem_shared[T.int64(0), v0, v1])
+ for ax2_0 in range((m + T.int64(255)) // T.int64(256)):
+ for ax2_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
with T.block("compute"):
- v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
- v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused
// n)
- v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n)
- v_i3 = T.axis.spatial(m, i3_0 * T.int64(256) +
i3_1)
- T.where(i3_0 * T.int64(256) + i3_1 < m)
- T.reads(lv44[v_i0, v_i1, v_i2, v_i3],
T_softmax_maxelem_shared[v_i0, v_i1, v_i2], T_softmax_expsum_shared[v_i0, v_i1,
v_i2])
- T.writes(var_compute_intermediate[v_i0, v_i1,
v_i2, v_i3])
- var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] =
T.Cast("float16", T.exp(lv44[v_i0, v_i1, v_i2, v_i3] -
T_softmax_maxelem_shared[v_i0, v_i1, v_i2]) / T_softmax_expsum_shared[v_i0,
v_i1, v_i2])
+ v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused //
n)
+ v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+ v2 = T.axis.spatial(m, ax2_0 * T.int64(256) +
ax2_1)
+ T.where(ax2_0 * T.int64(256) + ax2_1 < m)
+ T.reads(lv44[T.int64(0), v0, v1, v2],
T_softmax_maxelem_shared[T.int64(0), v0, v1],
T_softmax_expsum_shared[T.int64(0), v0, v1])
+ T.writes(var_compute_intermediate[T.int64(0), v0,
v1, v2])
+ var_compute_intermediate[T.int64(0), v0, v1, v2] =
T.Cast("float16", T.exp(lv44[T.int64(0), v0, v1, v2] -
T_softmax_maxelem_shared[T.int64(0), v0, v1]) /
T_softmax_expsum_shared[T.int64(0), v0, v1])
# fmt: on
_check(Before, After)
@@ -185,31 +182,29 @@ def test_layer_norm():
# with T.block("root"):
A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n),
scope="shared")
A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n),
scope="shared")
- for i0_i1_fused in T.thread_binding(n, thread="blockIdx.x",
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1),
T.int64(10)):
- for ax2_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ for ax0_fused in T.thread_binding(n, thread="blockIdx.x",
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax0, ax1_fused_0 in T.grid(T.int64(1), T.int64(10)):
+ for ax1_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
with T.block("A_red_temp"):
- v_ax0 = T.axis.spatial(T.int64(1), ax0)
- v_ax1 = T.axis.spatial(n, i0_i1_fused + ax1)
- v_k2 = T.axis.reduce(T.int64(2560), ax2_fused_0 *
T.int64(256) + ax2_fused_1)
- T.reads(lv6[v_ax0, v_ax1, v_k2])
- T.writes(A_red_temp_v0_shared[v_ax0, v_ax1],
A_red_temp_v1_shared[v_ax0, v_ax1])
+ v0 = T.axis.spatial(n, ax0_fused + ax0)
+ v1 = T.axis.reduce(T.int64(2560), ax1_fused_0 *
T.int64(256) + ax1_fused_1)
+ T.reads(lv6[T.int64(0), v0, v1])
+ T.writes(A_red_temp_v0_shared[T.int64(0), v0],
A_red_temp_v1_shared[T.int64(0), v0])
with T.init():
- A_red_temp_v0_shared[v_ax0, v_ax1] =
T.float32(0)
- A_red_temp_v1_shared[v_ax0, v_ax1] =
T.float32(0)
- v_A_red_temp_v0: T.float32 =
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
- v_A_red_temp_v1: T.float32 =
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0,
v_ax1, v_k2]
- A_red_temp_v0_shared[v_ax0, v_ax1] =
v_A_red_temp_v0
- A_red_temp_v1_shared[v_ax0, v_ax1] =
v_A_red_temp_v1
- for i2_0 in range(T.int64(10)):
- for i2_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ A_red_temp_v0_shared[T.int64(0), v0] =
T.float32(0)
+ A_red_temp_v1_shared[T.int64(0), v0] =
T.float32(0)
+ v_A_red_temp_v0: T.float32 =
A_red_temp_v0_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1]
+ v_A_red_temp_v1: T.float32 =
A_red_temp_v1_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] *
lv6[T.int64(0), v0, v1]
+ A_red_temp_v0_shared[T.int64(0), v0] =
v_A_red_temp_v0
+ A_red_temp_v1_shared[T.int64(0), v0] =
v_A_red_temp_v1
+ for ax1_0 in range(T.int64(10)):
+ for ax1_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
with T.block("compute"):
- v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
- v_i1 = T.axis.spatial(n, i0_i1_fused)
- v_i2 = T.axis.spatial(T.int64(2560), i2_0 *
T.int64(256) + i2_1)
- T.reads(lv6[v_i0, v_i1, v_i2],
A_red_temp_v0_shared[v_i0, v_i1], A_red_temp_v1_shared[v_i0, v_i1],
weight1[v_i2], bias[v_i2])
- T.writes(var_compute_intermediate[v_i0, v_i1,
v_i2])
- var_compute_intermediate[v_i0, v_i1, v_i2] =
T.Cast("float16", (lv6[v_i0, v_i1, v_i2] - A_red_temp_v0_shared[v_i0, v_i1] *
T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[v_i0, v_i1] *
T.float32(0.00039062500000000002) - A_red_temp_v0_shared[v_i0, v_i1] *
T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[v_i0, v_i1] *
T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) *
weight1[v_i2] + bias[v_i2])
+ v0 = T.axis.spatial(n, ax0_fused)
+ v1 = T.axis.spatial(T.int64(2560), ax1_0 *
T.int64(256) + ax1_1)
+ T.reads(lv6[T.int64(0), v0, v1],
A_red_temp_v0_shared[T.int64(0), v0], A_red_temp_v1_shared[T.int64(0), v0],
weight1[v1], bias[v1])
+ T.writes(var_compute_intermediate[T.int64(0), v0,
v1])
+ var_compute_intermediate[T.int64(0), v0, v1] =
T.Cast("float16", (lv6[T.int64(0), v0, v1] - A_red_temp_v0_shared[T.int64(0),
v0] * T.float32(0.00039062500000000002)) *
T.rsqrt(A_red_temp_v1_shared[T.int64(0), v0] *
T.float32(0.00039062500000000002) - A_red_temp_v0_shared[T.int64(0), v0] *
T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[T.int64(0), v0] *
T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) *
weight1[v1] + bias[v1])
# fmt: on
_check(Before, After)
@@ -251,27 +246,25 @@ def test_rms_norm():
rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n,
T.int64(4096)), "float16")
# with T.block("root"):
Ared_temp_shared = T.alloc_buffer((T.int64(1), n), scope="shared")
- for bsz_i_fused in T.thread_binding(n, thread="blockIdx.x",
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1),
T.int64(16)):
- for ax2_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ for ax0_fused in T.thread_binding(n, thread="blockIdx.x",
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax0, ax1_fused_0 in T.grid(T.int64(1), T.int64(16)):
+ for ax1_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
with T.block("Ared_temp"):
- v_bsz = T.axis.spatial(T.int64(1), ax0)
- v_i = T.axis.spatial(n, bsz_i_fused + ax1)
- v_k = T.axis.reduce(T.int64(4096), ax2_fused_0 *
T.int64(256) + ax2_fused_1)
- T.reads(A[v_bsz, v_i, v_k])
- T.writes(Ared_temp_shared[v_bsz, v_i])
+ v0 = T.axis.spatial(n, ax0_fused + ax0)
+ v1 = T.axis.reduce(T.int64(4096), ax1_fused_0 *
T.int64(256) + ax1_fused_1)
+ T.reads(A[T.int64(0), v0, v1])
+ T.writes(Ared_temp_shared[T.int64(0), v0])
with T.init():
- Ared_temp_shared[v_bsz, v_i] = T.float32(0)
- Ared_temp_shared[v_bsz, v_i] =
Ared_temp_shared[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) *
T.Cast("float32", A[v_bsz, v_i, v_k])
- for k_0 in range(T.int64(16)):
- for k_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ Ared_temp_shared[T.int64(0), v0] = T.float32(0)
+ Ared_temp_shared[T.int64(0), v0] =
Ared_temp_shared[T.int64(0), v0] + T.Cast("float32", A[T.int64(0), v0, v1]) *
T.Cast("float32", A[T.int64(0), v0, v1])
+ for ax1_0 in range(T.int64(16)):
+ for ax1_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
with T.block("rms_norm"):
- v_bsz = T.axis.spatial(T.int64(1), T.int64(0))
- v_i = T.axis.spatial(n, bsz_i_fused)
- v_k = T.axis.spatial(T.int64(4096), k_0 *
T.int64(256) + k_1)
- T.reads(B[v_k], A[v_bsz, v_i, v_k],
Ared_temp_shared[v_bsz, v_i])
- T.writes(rms_norm_1[v_bsz, v_i, v_k])
- rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16",
T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) /
T.sqrt(Ared_temp_shared[v_bsz, v_i] * T.float32(0.000244140625) +
T.float32(9.9999999999999995e-07))))
+ v0 = T.axis.spatial(n, ax0_fused)
+ v1 = T.axis.spatial(T.int64(4096), ax1_0 *
T.int64(256) + ax1_1)
+ T.reads(B[v1], A[T.int64(0), v0, v1],
Ared_temp_shared[T.int64(0), v0])
+ T.writes(rms_norm_1[T.int64(0), v0, v1])
+ rms_norm_1[T.int64(0), v0, v1] = T.Cast("float16",
T.Cast("float32", B[v1]) * (T.Cast("float32", A[T.int64(0), v0, v1]) /
T.sqrt(Ared_temp_shared[T.int64(0), v0] * T.float32(0.000244140625) +
T.float32(9.9999999999999995e-07))))
# fmt: on
_check(Before, After)
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py
b/tests/python/unittest/test_tir_schedule_compute_inline.py
index 8d90189507..b9779c34cf 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -694,6 +694,23 @@ def elementwise_producer_not_cover_consumer(
D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj],
T.float32(0), dtype="float32")
[email protected]_func
+def elementwise_producer_is_reduction(
+ A: T.Buffer((128, 128), "float32"), D: T.Buffer((128), "float32")
+) -> None:
+ B = T.alloc_buffer((128))
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SR", [i, j])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vj]
+ for i in T.grid(128):
+ with T.block("C"):
+ vi = T.axis.remap("S", [i])
+ D[vi] = B[vi] + 1.0
+
+
@T.prim_func
def elementwise_predicate_producer(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
@@ -1267,6 +1284,13 @@ def
test_reverse_compute_inline_producer_predicate_disallowed():
)
+def test_reverse_compute_inline_producer_is_reduction():
+ """Test reverse compute inline when producer is reduction"""
+ sch = tir.Schedule(elementwise_producer_is_reduction, debug_mask="all")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.reverse_compute_inline(sch.get_block("C"))
+
+
def test_compute_inline_softmax():
# fmt: off
@T.prim_func