This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e1725ac01d [Unity][Dlight] Fallback rule supporting more spatial 
workloads (#15687)
e1725ac01d is described below

commit e1725ac01d0cb9b2ef197fd9ab9b7feb5de2f440
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Sep 6 16:30:15 2023 -0400

    [Unity][Dlight] Fallback rule supporting more spatial workloads (#15687)
    
    This PR enhances the fallback dlight GPU rule so that it can support
    more non-trivial spatial workloads.
    
    Particularly, to this end, this PR makes the following changes:
    
    * for function normalization: when "TransformLayout" cannot be applied
    to a block of a TIR PrimFunc, the normalization now gives up the layout
    transformation instead of letting TransformLayout throw an error.
    * for the fallback rule, if the return value of the normalization is
    None, it will return directly instead of proceeding with the inline attempt.
    * for the fallback rule, when iterating the blocks in the PrimFunc,
    if the loops of a block are already bound to GPU threads, the block
    will be skipped so the loops will not be processed again.
---
 python/tvm/dlight/gpu/fallback.py        | 13 +++++-
 src/tir/schedule/transform.cc            |  6 ++-
 tests/python/dlight/test_gpu_fallback.py | 70 ++++++++++++++++++++++++++++++++
 3 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/python/tvm/dlight/gpu/fallback.py 
b/python/tvm/dlight/gpu/fallback.py
index d209f88ec0..3e0dbbcdaa 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -40,7 +40,12 @@ class Fallback(ScheduleRule):
         max_threads_per_block = utils.max_threads_per_block(target)
 
         sch = tir.Schedule(func)
-        block_infos = try_inline(sch, normalize_prim_func(sch))
+        block_infos = normalize_prim_func(sch)
+
+        if block_infos is None:
+            return None
+
+        block_infos = try_inline(sch, block_infos)
         reduction_blocks: List[Tuple[tir.schedule.BlockRV, 
tir.schedule.LoopRV]] = []
         for block in block_infos:
             s_loops: List[tir.schedule.LoopRV] = []
@@ -48,6 +53,12 @@ class Fallback(ScheduleRule):
             o_loops: List[tir.schedule.LoopRV] = []
             dom_kind = block.dom_kind()
             block = block.block_rv
+
+            if any(
+                [sch.get(loop_rv).thread_binding is not None for loop_rv in 
sch.get_loops(block)]
+            ):
+                continue
+
             for loop, iter_type in zip(sch.get_loops(block), dom_kind):
                 {"S": s_loops, "R": r_loops, "O": 
o_loops}[iter_type].append(loop)
 
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 31c638f378..05fae3d102 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -472,7 +472,11 @@ Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
     if (index_map_outputs.empty() || !has_spatial_iter) {
       index_map_outputs.insert(index_map_outputs.begin(), 
tir::make_const(DataType::Int(64), 0));
     }
-    sch->TransformBlockLayout(block, IndexMap(index_map_inputs, 
index_map_outputs));
+    try {
+      sch->TransformBlockLayout(block, IndexMap(index_map_inputs, 
index_map_outputs));
+    } catch (tvm::runtime::Error& e) {
+      // Skip layout transformation when not transformable.
+    }
     block_loops.push_back(sch->GetLoops(block));
     block_iters.push_back(sch->Get(block)->iter_vars);
     bool is_reduction = IsReductionBlock(sch->state(),         //
diff --git a/tests/python/dlight/test_gpu_fallback.py 
b/tests/python/dlight/test_gpu_fallback.py
index d3fce0ee99..4457e627bd 100644
--- a/tests/python/dlight/test_gpu_fallback.py
+++ b/tests/python/dlight/test_gpu_fallback.py
@@ -109,5 +109,75 @@ def test_fallback_reduction():
     assert_structural_equal(mod, Expected)
 
 
+def test_fallback_irregular_spatial():
+    @T.prim_func(private=True)
+    def func(
+        var_pages: T.handle,
+        var_page_table_indptr: T.handle,
+        var_page_table_values: T.handle,
+        var_values: T.handle,
+        seq_id: T.int32,
+    ):
+        nhead = T.int32()
+        nlayer = T.int32()
+        seqlen = T.int32()
+        npage = T.int32()
+        page_size = T.int32()
+        num_total_pages = T.int32()
+        num_total_seqs_plus_1 = T.int32()
+
+        pages = T.match_buffer(var_pages, (num_total_pages, nlayer, nhead, 
page_size), "float16")
+        page_table_indptr = T.match_buffer(var_page_table_indptr, 
(num_total_seqs_plus_1,), "int32")
+        page_table_values = T.match_buffer(var_page_table_values, (npage,), 
"int32")
+        values = T.match_buffer(var_values, (nlayer, nhead, seqlen), "float16")
+
+        for l, h, pos in T.grid(nlayer, nhead, seqlen):
+            with T.block("block"):
+                vl, vh, vp = T.axis.remap("SSS", [l, h, pos])
+                values[vl, vh, vp] = pages[
+                    page_table_values[page_table_indptr[seq_id] + 
T.floordiv(vp, page_size)],
+                    vl,
+                    vh,
+                    T.floormod(vp, page_size),
+                ]
+
+    # fmt: off
+    @T.prim_func(private=True)
+    def expected(var_pages: T.handle, var_page_table_indptr: T.handle, 
var_page_table_values: T.handle, var_values: T.handle, seq_id: T.int32):
+        T.func_attr({"tir.is_scheduled": 1})
+        nhead = T.int32()
+        nlayer = T.int32()
+        seqlen = T.int32()
+        npage = T.int32()
+        page_size = T.int32()
+        num_total_pages = T.int32()
+        num_total_seqs_plus_1 = T.int32()
+
+        pages = T.match_buffer(var_pages, (num_total_pages, nlayer, nhead, 
page_size), "float16")
+        page_table_indptr = T.match_buffer(var_page_table_indptr, 
(num_total_seqs_plus_1,), "int32")
+        page_table_values = T.match_buffer(var_page_table_values, (npage,), 
"int32")
+        values = T.match_buffer(var_values, (nlayer, nhead, seqlen), "float16")
+
+        for ax0_ax1_ax2_fused_0 in T.thread_binding((nlayer * nhead * seqlen + 
1023) // 1024, thread="blockIdx.x"):
+            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, 
thread="threadIdx.x"):
+                with T.block("block"):
+                    v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + 
ax0_ax1_ax2_fused_1) % (seqlen * nhead * nlayer) // (seqlen * nhead))
+                    v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + 
ax0_ax1_ax2_fused_1) % (seqlen * nhead) // seqlen)
+                    v2 = T.axis.spatial(seqlen, (ax0_ax1_ax2_fused_0 * 1024 + 
ax0_ax1_ax2_fused_1) % seqlen)
+                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < 
nlayer * nhead * seqlen)
+                    T.reads(pages[page_table_values[page_table_indptr[seq_id] 
+ v2 // page_size], v0, v1, v2 % page_size], 
page_table_values[page_table_indptr[seq_id] + v2 // page_size], 
page_table_indptr[seq_id])
+                    T.writes(values[v0, v1, v2])
+                    values[v0, v1, v2] = 
pages[page_table_values[page_table_indptr[seq_id] + v2 // page_size], v0, v1, 
v2 % page_size]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = tvm.IRModule({"main": func})
+        mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
+            dl.gpu.Fallback(),
+        )(mod)
+    assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to