This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e1725ac01d [Unity][Dlight] Fallback rule supporting more spatial
workloads (#15687)
e1725ac01d is described below
commit e1725ac01d0cb9b2ef197fd9ab9b7feb5de2f440
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Sep 6 16:30:15 2023 -0400
[Unity][Dlight] Fallback rule supporting more spatial workloads (#15687)
This PR enhances the fallback dlight GPU rule so that it can support
more non-trivial spatial workloads.
Particularly, to this end, this PR makes the following changes:
* for function normalization: when "TransformLayout" cannot be applied
to a block of a TIR PrimFunc, the normalization now gives up the layout
transformation instead of throwing an error in TransformLayout.
* for the fallback rule, if the return value of the normalization is
None, it now returns directly instead of proceeding with the inline attempt.
* for the fallback rule, when iterating over the blocks in the PrimFunc,
if the loops of a block are already bound to GPU threads, the block
will be skipped so that its loops are not processed again.
---
python/tvm/dlight/gpu/fallback.py | 13 +++++-
src/tir/schedule/transform.cc | 6 ++-
tests/python/dlight/test_gpu_fallback.py | 70 ++++++++++++++++++++++++++++++++
3 files changed, 87 insertions(+), 2 deletions(-)
diff --git a/python/tvm/dlight/gpu/fallback.py
b/python/tvm/dlight/gpu/fallback.py
index d209f88ec0..3e0dbbcdaa 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -40,7 +40,12 @@ class Fallback(ScheduleRule):
max_threads_per_block = utils.max_threads_per_block(target)
sch = tir.Schedule(func)
- block_infos = try_inline(sch, normalize_prim_func(sch))
+ block_infos = normalize_prim_func(sch)
+
+ if block_infos is None:
+ return None
+
+ block_infos = try_inline(sch, block_infos)
reduction_blocks: List[Tuple[tir.schedule.BlockRV,
tir.schedule.LoopRV]] = []
for block in block_infos:
s_loops: List[tir.schedule.LoopRV] = []
@@ -48,6 +53,12 @@ class Fallback(ScheduleRule):
o_loops: List[tir.schedule.LoopRV] = []
dom_kind = block.dom_kind()
block = block.block_rv
+
+ if any(
+ [sch.get(loop_rv).thread_binding is not None for loop_rv in
sch.get_loops(block)]
+ ):
+ continue
+
for loop, iter_type in zip(sch.get_loops(block), dom_kind):
{"S": s_loops, "R": r_loops, "O":
o_loops}[iter_type].append(loop)
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 31c638f378..05fae3d102 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -472,7 +472,11 @@ Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
if (index_map_outputs.empty() || !has_spatial_iter) {
index_map_outputs.insert(index_map_outputs.begin(),
tir::make_const(DataType::Int(64), 0));
}
- sch->TransformBlockLayout(block, IndexMap(index_map_inputs,
index_map_outputs));
+ try {
+ sch->TransformBlockLayout(block, IndexMap(index_map_inputs,
index_map_outputs));
+ } catch (tvm::runtime::Error& e) {
+ // Skip layout transformation when not transformable.
+ }
block_loops.push_back(sch->GetLoops(block));
block_iters.push_back(sch->Get(block)->iter_vars);
bool is_reduction = IsReductionBlock(sch->state(), //
diff --git a/tests/python/dlight/test_gpu_fallback.py
b/tests/python/dlight/test_gpu_fallback.py
index d3fce0ee99..4457e627bd 100644
--- a/tests/python/dlight/test_gpu_fallback.py
+++ b/tests/python/dlight/test_gpu_fallback.py
@@ -109,5 +109,75 @@ def test_fallback_reduction():
assert_structural_equal(mod, Expected)
+def test_fallback_irregular_spatial():
+ @T.prim_func(private=True)
+ def func(
+ var_pages: T.handle,
+ var_page_table_indptr: T.handle,
+ var_page_table_values: T.handle,
+ var_values: T.handle,
+ seq_id: T.int32,
+ ):
+ nhead = T.int32()
+ nlayer = T.int32()
+ seqlen = T.int32()
+ npage = T.int32()
+ page_size = T.int32()
+ num_total_pages = T.int32()
+ num_total_seqs_plus_1 = T.int32()
+
+ pages = T.match_buffer(var_pages, (num_total_pages, nlayer, nhead,
page_size), "float16")
+ page_table_indptr = T.match_buffer(var_page_table_indptr,
(num_total_seqs_plus_1,), "int32")
+ page_table_values = T.match_buffer(var_page_table_values, (npage,),
"int32")
+ values = T.match_buffer(var_values, (nlayer, nhead, seqlen), "float16")
+
+ for l, h, pos in T.grid(nlayer, nhead, seqlen):
+ with T.block("block"):
+ vl, vh, vp = T.axis.remap("SSS", [l, h, pos])
+ values[vl, vh, vp] = pages[
+ page_table_values[page_table_indptr[seq_id] +
T.floordiv(vp, page_size)],
+ vl,
+ vh,
+ T.floormod(vp, page_size),
+ ]
+
+ # fmt: off
+ @T.prim_func(private=True)
+ def expected(var_pages: T.handle, var_page_table_indptr: T.handle,
var_page_table_values: T.handle, var_values: T.handle, seq_id: T.int32):
+ T.func_attr({"tir.is_scheduled": 1})
+ nhead = T.int32()
+ nlayer = T.int32()
+ seqlen = T.int32()
+ npage = T.int32()
+ page_size = T.int32()
+ num_total_pages = T.int32()
+ num_total_seqs_plus_1 = T.int32()
+
+ pages = T.match_buffer(var_pages, (num_total_pages, nlayer, nhead,
page_size), "float16")
+ page_table_indptr = T.match_buffer(var_page_table_indptr,
(num_total_seqs_plus_1,), "int32")
+ page_table_values = T.match_buffer(var_page_table_values, (npage,),
"int32")
+ values = T.match_buffer(var_values, (nlayer, nhead, seqlen), "float16")
+
+ for ax0_ax1_ax2_fused_0 in T.thread_binding((nlayer * nhead * seqlen +
1023) // 1024, thread="blockIdx.x"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(1024,
thread="threadIdx.x"):
+ with T.block("block"):
+ v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 +
ax0_ax1_ax2_fused_1) % (seqlen * nhead * nlayer) // (seqlen * nhead))
+ v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 +
ax0_ax1_ax2_fused_1) % (seqlen * nhead) // seqlen)
+ v2 = T.axis.spatial(seqlen, (ax0_ax1_ax2_fused_0 * 1024 +
ax0_ax1_ax2_fused_1) % seqlen)
+ T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 <
nlayer * nhead * seqlen)
+ T.reads(pages[page_table_values[page_table_indptr[seq_id]
+ v2 // page_size], v0, v1, v2 % page_size],
page_table_values[page_table_indptr[seq_id] + v2 // page_size],
page_table_indptr[seq_id])
+ T.writes(values[v0, v1, v2])
+ values[v0, v1, v2] =
pages[page_table_values[page_table_indptr[seq_id] + v2 // page_size], v0, v1,
v2 % page_size]
+ # fmt: on
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ with target:
+ mod = tvm.IRModule({"main": func})
+ mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
+ dl.gpu.Fallback(),
+ )(mod)
+ assert_structural_equal(mod["main"], expected)
+
+
if __name__ == "__main__":
tvm.testing.main()