This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 15a6b475bb [Unity][Dlight] Handle Epilogue Broadcasting (#15252)
15a6b475bb is described below
commit 15a6b475bbb667bf04fb298a1ea9e31ef55814c3
Author: Junru Shao <[email protected]>
AuthorDate: Thu Jul 6 09:02:31 2023 -0700
[Unity][Dlight] Handle Epilogue Broadcasting (#15252)
This PR improves the Decode-GEMV scheduling by further analyzing its
epilogue pattern.
The existing behavior assumes that the outcome of cross-thread reduction
stays in register files local to each thread, which is further used to
calculate the epilogue in the same thread.
This strategy means the cross-thread reduction outcome is stored only on
thread 0, while the other threads cannot participate in subsequent
computation (i.e. epilogue). Related:
https://github.com/apache/tvm/pull/15192.
When the epilogue is relatively lightweight, i.e. elementwise add,
casting on scalars, this strategy is optimal. However, once the outcome
needs to be broadcast to compute over a non-trivial region, for
example, act as a normalizer of `np.mean`, it would become much slower
because only one thread in a thread block is effectively used.
In this case, we will need to broadcast the cross-thread reduction
outcome in shared memory, making it visible to other threads, and then
bind the compute region to all threads in the thread block.
---
python/tvm/dlight/base/analysis.py | 25 +++++++-
python/tvm/dlight/gpu/decode_gemv.py | 94 ++++++++++++++++++++++-------
tests/python/dlight/test_gpu_decode_gemv.py | 24 ++++----
3 files changed, 107 insertions(+), 36 deletions(-)
diff --git a/python/tvm/dlight/base/analysis.py
b/python/tvm/dlight/base/analysis.py
index 2607968ef2..6e16239910 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -17,12 +17,13 @@
"""Analysis on TIR blocks, loops and functions."""
from typing import List, Optional, Union
-from tvm import tir
+from typing_extensions import Literal
+
+from tvm import ir, tir
from tvm._ffi import get_global_func
from tvm.target.target import Target
from tvm.tir import Schedule
from tvm.tir.schedule import BlockRV
-from typing_extensions import Literal
class IterInfo:
@@ -91,6 +92,26 @@ class BlockInfo:
"""Whether the block is injective, i.e. all its iteration domains are
injective."""
return all(k == "S" for k in self.dom_kind())
+ def is_elementwise(self, sch: tir.Schedule) -> bool:
+ """Whether the block is elementwise, i.e. trivial mapping between
read/write region"""
+
+ def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool:
+ return dom.min.same_as(var) and dom.extent == 1
+
+ if not self.is_injective():
+ return False
+ block = sch.get(self.block_rv)
+ if len(block.reads) != 1 or len(block.writes) != 1:
+ return False
+ r_region = block.reads[0].region
+ w_region = block.writes[0].region
+ if len(r_region) != len(w_region):
+ return False
+ for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region):
+ if not _check_unit_var_range(r_dom, var.var) or not _check_unit_var_range(w_dom, var.var):
+ return False
+ return True
+
def is_reduction(self) -> bool:
"""Whether the block is a reduction workload."""
# TODO(@junrushao): distinguish GEMV and reduction
diff --git a/python/tvm/dlight/gpu/decode_gemv.py
b/python/tvm/dlight/gpu/decode_gemv.py
index 69da1b0b0a..d0d37e8476 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -17,8 +17,7 @@
"""A rule for DecodeGEMV."""
from typing import List, Optional, Set, Tuple, Union
-from tvm import arith, tir
-from tvm.ir import structural_equal
+from tvm import arith, ir, tir
from tvm.target import Target
from ..base import (
@@ -37,7 +36,7 @@ def _get_reduction_expr(block: tir.Block) ->
Optional[tir.PrimExpr]:
return None
if not isinstance(buffer_store.value, tir.Add):
return None
- if not structural_equal(
+ if not ir.structural_equal(
buffer_store.value.a,
tir.BufferLoad(buffer_store.buffer, block.body.indices),
map_free_vars=True,
@@ -46,28 +45,48 @@ def _get_reduction_expr(block: tir.Block) ->
Optional[tir.PrimExpr]:
return buffer_store.value.b
-def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
- dominant_read, read_iters = None, None
+def _collect_vars_used_in_access_region(region: List[ir.Range]) ->
Set[tir.Var]:
tir_vars: Set[tir.Var] = set()
- for buffer_region in block.reads:
- tir_vars.clear()
- def _collect_tir_var(expr):
- if isinstance(expr, tir.Var):
- tir_vars.add(expr)
+ def _collect_tir_var(expr):
+ if isinstance(expr, tir.Var):
+ tir_vars.add(expr)
+
+ for expr in region:
+ assert expr.extent == 1
+ tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
+ return tir_vars
- for expr in buffer_region.region:
- assert expr.extent == 1
- tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
- if read_iters is None or read_iters < len(tir_vars):
- read_iters = len(tir_vars)
+def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
+ dominant_read = None
+ num_read_iters = -1
+ for buffer_region in block.reads:
+ tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
+ if num_read_iters < len(tir_vars):
+ num_read_iters = len(tir_vars)
dominant_read = buffer_region
assert dominant_read is not None
(result,) = dominant_read.buffer.offset_of([e.min for e in
dominant_read.region])
return result
+def _is_broadcast_epilogue(
+ sch: tir.Schedule,
+ block: tir.schedule.BlockRV,
+ epilogue: tir.schedule.BlockRV,
+) -> bool:
+ write_buffers = {r.buffer for r in sch.get(block).writes}
+ epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom
!= 1}
+ for buffer_region in sch.get(epilogue).reads:
+ if buffer_region.buffer not in write_buffers:
+ continue
+ tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
+ if len(tir_vars) < len(epilogue_iters):
+ return True
+ return False
+
+
class DecodeGEMV(ScheduleRule):
"""A rule for DecodeGEMV."""
@@ -81,8 +100,16 @@ class DecodeGEMV(ScheduleRule):
return None
sch = tir.Schedule(func)
block_infos = normalize_prim_func(sch)
+ if block_infos is None:
+ return None
block_infos = try_inline_contiguous_spatial(sch, block_infos)
- if block_infos is None or len(block_infos) > 2:
+ if len(block_infos) == 1:
+ epilogue = None
+ elif len(block_infos) == 2:
+ epilogue = block_infos[1]
+ if not epilogue.is_injective():
+ return None
+ else:
return None
block_info = block_infos[0]
@@ -109,13 +136,9 @@ class DecodeGEMV(ScheduleRule):
return None
# Step 3. Do the scheduling
if is_inner_reduction:
- self._sch_inner_reduction(sch, target, block, c_factor)
+ self._sch_inner_reduction(sch, target, block, c_factor, epilogue)
else:
- self._sch_inner_spatial(sch, target, block, c_factor)
- # Step 4. Schedule epilogue
- if len(block_infos) == 2:
- sch.set_scope(block, 0, "local")
- sch.reverse_compute_at(block_infos[1].block_rv,
sch.get_loops(block)[0])
+ self._sch_inner_spatial(sch, target, block, c_factor, epilogue)
return sch
def _normalize( # pylint: disable=too-many-branches
@@ -162,12 +185,13 @@ class DecodeGEMV(ScheduleRule):
sch.fuse(*r_loops)
return is_inner_reduction, c_factor
- def _sch_inner_reduction(
+ def _sch_inner_reduction( # pylint: disable=too-many-arguments
self,
sch: tir.Schedule,
target: Target,
block: tir.schedule.BlockRV,
unroll_spatial_factor: Optional[int],
+ epilogue_info: Optional[BlockInfo],
):
# pylint: disable=invalid-name
_, r, _ = sch.get_loops(block)
@@ -193,6 +217,17 @@ class DecodeGEMV(ScheduleRule):
s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
sch.reorder(s, tx, inner)
sch.bind(tx, "threadIdx.x")
+ # Schedule epilogue
+ if epilogue_info is not None:
+ epilogue = epilogue_info.block_rv
+ sch.reverse_compute_at(epilogue, bx)
+ if _is_broadcast_epilogue(sch, block, epilogue):
+ sch.set_scope(block, 0, "shared")
+ _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name
+ _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx])
+ sch.bind(tx, "threadIdx.x")
+ else:
+ sch.set_scope(block, 0, "local")
# pylint: enable=invalid-name
def _sch_inner_spatial(
@@ -201,6 +236,7 @@ class DecodeGEMV(ScheduleRule):
_: Target,
block: tir.schedule.BlockRV,
unroll_spatial_factor: Optional[int],
+ epilogue_info: Optional[BlockInfo],
):
# pylint: disable=invalid-name
s, r, _ = sch.get_loops(block)
@@ -226,4 +262,16 @@ class DecodeGEMV(ScheduleRule):
sch.reorder(s, r, inner)
sch.bind(s, "threadIdx.x")
sch.bind(r, "threadIdx.y")
+ # Schedule epilogue
+ if epilogue_info is not None:
+ epilogue = epilogue_info.block_rv
+ sch.reverse_compute_at(epilogue, bx)
+ if _is_broadcast_epilogue(sch, block, epilogue):
+ sch.set_scope(block, 0, "shared")
+ _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name
+ _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx,
len_ty])
+ sch.bind(tx, "threadIdx.x")
+ sch.bind(ty, "threadIdx.y")
+ else:
+ sch.set_scope(block, 0, "local")
# pylint: enable=invalid-name
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py
b/tests/python/dlight/test_gpu_decode_gemv.py
index ba923a7d21..bd84aeb096 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -428,6 +428,7 @@ def test_reduction_no_spatial():
class Before:
@T.prim_func
def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,),
"float16"), rms_norm: T.Buffer((1, 4096), "float16")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
Ared_temp = T.alloc_buffer((1, 1))
for ax0 in range(4096):
with T.block("Ared_temp"):
@@ -444,9 +445,9 @@ def test_reduction_no_spatial():
class After:
@T.prim_func
def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,),
"float16"), rms_norm: T.Buffer((1, 4096), "float16")):
- T.func_attr({"tir.is_scheduled": 1})
+ T.func_attr({"global_symbol": "main", "tir.noalias": True,
"tir.is_scheduled": 1})
# with T.block("root"):
- Ared_temp_local = T.alloc_buffer((1, 1), scope="local")
+ Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared")
Ared_temp_rf_local = T.alloc_buffer((256, 1, 1), scope="local")
for ax0_fused in T.thread_binding(T.int64(1),
thread="blockIdx.x"): # pylint: disable=unused-variable
for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
@@ -470,16 +471,17 @@ def test_reduction_no_spatial():
vax1_fused_1 = T.axis.reduce(256, ax0)
v0 = T.axis.spatial(T.int64(1), T.int64(0))
T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0])
- T.writes(Ared_temp_local[0, 0])
+ T.writes(Ared_temp_shared[0, 0])
with T.init():
- Ared_temp_local[0, 0] = T.float32(0)
- Ared_temp_local[0, 0] = Ared_temp_local[0, 0] +
Ared_temp_rf_local[vax1_fused_1, 0, 0]
- for ax0 in range(4096):
- with T.block("rms_norm"):
- v0 = T.axis.spatial(4096, ax0)
- T.reads(B[v0], A[0, 0, v0], Ared_temp_local[0, 0])
- T.writes(rms_norm[0, v0])
- rms_norm[0, v0] = T.Cast("float16", T.Cast("float32",
B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp_local[0, 0] *
T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
+ Ared_temp_shared[0, 0] = T.float32(0)
+ Ared_temp_shared[0, 0] = Ared_temp_shared[0, 0] +
Ared_temp_rf_local[vax1_fused_1, 0, 0]
+ for ax0_fused_0 in range(16):
+ for ax0_fused_1 in T.thread_binding(256,
thread="threadIdx.x"):
+ with T.block("rms_norm"):
+ v0 = T.axis.spatial(4096, ax0_fused_0 * 256 +
ax0_fused_1)
+ T.reads(B[v0], A[0, 0, v0], Ared_temp_shared[0, 0])
+ T.writes(rms_norm[0, v0])
+ rms_norm[0, v0] = T.Cast("float16",
T.Cast("float32", B[v0]) * (T.Cast("float32", A[0, 0, v0]) /
T.sqrt(Ared_temp_shared[0, 0] * T.float32(0.000244140625) +
T.float32(9.9999999999999995e-07))))
# fmt: on
target = Target("nvidia/geforce-rtx-3090-ti")
with target: