This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 15a6b475bb [Unity][Dlight] Handle Epilogue Broadcasting (#15252)
15a6b475bb is described below

commit 15a6b475bbb667bf04fb298a1ea9e31ef55814c3
Author: Junru Shao <[email protected]>
AuthorDate: Thu Jul 6 09:02:31 2023 -0700

    [Unity][Dlight] Handle Epilogue Broadcasting (#15252)
    
    This PR improves the Decode-GEMV scheduling by further analyzing its
    epilogue pattern.
    
    The existing behavior assumes that the outcome of cross-thread reduction
    stays in register files local to each thread, which is further used to
    calculate the epilogue in the same thread.
    
    This strategy means the cross-thread reduction outcome is stored only on
    thread 0, while the other threads cannot participate in the subsequent
    computation (i.e. the epilogue). Related:
    https://github.com/apache/tvm/pull/15192.
    
    When the epilogue is relatively lightweight, e.g. an elementwise add or
    a scalar cast, this strategy is optimal. However, once the outcome
    needs to be broadcast to compute over a non-trivial region, for
    example, acting as the normalizer in `np.mean`, it would become much
    slower because only one thread in a thread block is effectively used.
    
    In this case, we will need to broadcast the cross-thread reduction
    outcome in shared memory, making it visible to the other threads, and
    then bind the compute region to all threads in the thread block.
---
 python/tvm/dlight/base/analysis.py          | 25 +++++++-
 python/tvm/dlight/gpu/decode_gemv.py        | 94 ++++++++++++++++++++++-------
 tests/python/dlight/test_gpu_decode_gemv.py | 24 ++++----
 3 files changed, 107 insertions(+), 36 deletions(-)

diff --git a/python/tvm/dlight/base/analysis.py 
b/python/tvm/dlight/base/analysis.py
index 2607968ef2..6e16239910 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -17,12 +17,13 @@
 """Analysis on TIR blocks, loops and functions."""
 from typing import List, Optional, Union
 
-from tvm import tir
+from typing_extensions import Literal
+
+from tvm import ir, tir
 from tvm._ffi import get_global_func
 from tvm.target.target import Target
 from tvm.tir import Schedule
 from tvm.tir.schedule import BlockRV
-from typing_extensions import Literal
 
 
 class IterInfo:
@@ -91,6 +92,26 @@ class BlockInfo:
         """Whether the block is injective, i.e. all its iteration domains are 
injective."""
         return all(k == "S" for k in self.dom_kind())
 
+    def is_elementwise(self, sch: tir.Schedule) -> bool:
+        """Whether the block is elementwise, i.e. trivial mapping between 
read/write region"""
+
+        def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool:
+            return dom.min.same_as(var) and dom.extent == 1
+
+        if not self.is_injective():
+            return False
+        block = sch.get(self.block_rv)
+        if len(block.reads) != 1 or len(block.writes) != 1:
+            return False
+        r_region = block.reads[0].region
+        w_region = block.writes[0].region
+        if len(r_region) != len(w_region):
+            return False
+        for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region):
+            if not _check_unit_var_range(var, r_dom) or not 
_check_unit_var_range(var, w_dom):
+                return False
+        return True
+
     def is_reduction(self) -> bool:
         """Whether the block is a reduction workload."""
         # TODO(@junrushao): distinguish GEMV and reduction
diff --git a/python/tvm/dlight/gpu/decode_gemv.py 
b/python/tvm/dlight/gpu/decode_gemv.py
index 69da1b0b0a..d0d37e8476 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -17,8 +17,7 @@
 """A rule for DecodeGEMV."""
 from typing import List, Optional, Set, Tuple, Union
 
-from tvm import arith, tir
-from tvm.ir import structural_equal
+from tvm import arith, ir, tir
 from tvm.target import Target
 
 from ..base import (
@@ -37,7 +36,7 @@ def _get_reduction_expr(block: tir.Block) -> 
Optional[tir.PrimExpr]:
         return None
     if not isinstance(buffer_store.value, tir.Add):
         return None
-    if not structural_equal(
+    if not ir.structural_equal(
         buffer_store.value.a,
         tir.BufferLoad(buffer_store.buffer, block.body.indices),
         map_free_vars=True,
@@ -46,28 +45,48 @@ def _get_reduction_expr(block: tir.Block) -> 
Optional[tir.PrimExpr]:
     return buffer_store.value.b
 
 
-def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
-    dominant_read, read_iters = None, None
+def _collect_vars_used_in_access_region(region: List[ir.Range]) -> 
Set[tir.Var]:
     tir_vars: Set[tir.Var] = set()
-    for buffer_region in block.reads:
-        tir_vars.clear()
 
-        def _collect_tir_var(expr):
-            if isinstance(expr, tir.Var):
-                tir_vars.add(expr)
+    def _collect_tir_var(expr):
+        if isinstance(expr, tir.Var):
+            tir_vars.add(expr)
+
+    for expr in region:
+        assert expr.extent == 1
+        tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
+    return tir_vars
 
-        for expr in buffer_region.region:
-            assert expr.extent == 1
-            tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
 
-        if read_iters is None or read_iters < len(tir_vars):
-            read_iters = len(tir_vars)
+def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
+    dominant_read = None
+    num_read_iters = -1
+    for buffer_region in block.reads:
+        tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
+        if num_read_iters < len(tir_vars):
+            num_read_iters = len(tir_vars)
             dominant_read = buffer_region
     assert dominant_read is not None
     (result,) = dominant_read.buffer.offset_of([e.min for e in 
dominant_read.region])
     return result
 
 
+def _is_broadcast_epilogue(
+    sch: tir.Schedule,
+    block: tir.schedule.BlockRV,
+    epilogue: tir.schedule.BlockRV,
+) -> bool:
+    write_buffers = {r.buffer for r in sch.get(block).writes}
+    epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom 
!= 1}
+    for buffer_region in sch.get(epilogue).reads:
+        if buffer_region.buffer not in write_buffers:
+            continue
+        tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
+        if len(tir_vars) < len(epilogue_iters):
+            return True
+    return False
+
+
 class DecodeGEMV(ScheduleRule):
     """A rule for DecodeGEMV."""
 
@@ -81,8 +100,16 @@ class DecodeGEMV(ScheduleRule):
             return None
         sch = tir.Schedule(func)
         block_infos = normalize_prim_func(sch)
+        if block_infos is None:
+            return None
         block_infos = try_inline_contiguous_spatial(sch, block_infos)
-        if block_infos is None or len(block_infos) > 2:
+        if len(block_infos) == 1:
+            epilogue = None
+        elif len(block_infos) == 2:
+            epilogue = block_infos[1]
+            if not epilogue.is_injective():
+                return None
+        else:
             return None
 
         block_info = block_infos[0]
@@ -109,13 +136,9 @@ class DecodeGEMV(ScheduleRule):
             return None
         # Step 3. Do the scheduling
         if is_inner_reduction:
-            self._sch_inner_reduction(sch, target, block, c_factor)
+            self._sch_inner_reduction(sch, target, block, c_factor, epilogue)
         else:
-            self._sch_inner_spatial(sch, target, block, c_factor)
-        # Step 4. Schedule epilogue
-        if len(block_infos) == 2:
-            sch.set_scope(block, 0, "local")
-            sch.reverse_compute_at(block_infos[1].block_rv, 
sch.get_loops(block)[0])
+            self._sch_inner_spatial(sch, target, block, c_factor, epilogue)
         return sch
 
     def _normalize(  # pylint: disable=too-many-branches
@@ -162,12 +185,13 @@ class DecodeGEMV(ScheduleRule):
         sch.fuse(*r_loops)
         return is_inner_reduction, c_factor
 
-    def _sch_inner_reduction(
+    def _sch_inner_reduction(  # pylint: disable=too-many-arguments
         self,
         sch: tir.Schedule,
         target: Target,
         block: tir.schedule.BlockRV,
         unroll_spatial_factor: Optional[int],
+        epilogue_info: Optional[BlockInfo],
     ):
         # pylint: disable=invalid-name
         _, r, _ = sch.get_loops(block)
@@ -193,6 +217,17 @@ class DecodeGEMV(ScheduleRule):
             s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
             sch.reorder(s, tx, inner)
         sch.bind(tx, "threadIdx.x")
+        # Schedule epilogue
+        if epilogue_info is not None:
+            epilogue = epilogue_info.block_rv
+            sch.reverse_compute_at(epilogue, bx)
+            if _is_broadcast_epilogue(sch, block, epilogue):
+                sch.set_scope(block, 0, "shared")
+                _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
+                _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx])
+                sch.bind(tx, "threadIdx.x")
+            else:
+                sch.set_scope(block, 0, "local")
         # pylint: enable=invalid-name
 
     def _sch_inner_spatial(
@@ -201,6 +236,7 @@ class DecodeGEMV(ScheduleRule):
         _: Target,
         block: tir.schedule.BlockRV,
         unroll_spatial_factor: Optional[int],
+        epilogue_info: Optional[BlockInfo],
     ):
         # pylint: disable=invalid-name
         s, r, _ = sch.get_loops(block)
@@ -226,4 +262,16 @@ class DecodeGEMV(ScheduleRule):
             sch.reorder(s, r, inner)
         sch.bind(s, "threadIdx.x")
         sch.bind(r, "threadIdx.y")
+        # Schedule epilogue
+        if epilogue_info is not None:
+            epilogue = epilogue_info.block_rv
+            sch.reverse_compute_at(epilogue, bx)
+            if _is_broadcast_epilogue(sch, block, epilogue):
+                sch.set_scope(block, 0, "shared")
+                _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
+                _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx, 
len_ty])
+                sch.bind(tx, "threadIdx.x")
+                sch.bind(ty, "threadIdx.x")
+            else:
+                sch.set_scope(block, 0, "local")
         # pylint: enable=invalid-name
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py 
b/tests/python/dlight/test_gpu_decode_gemv.py
index ba923a7d21..bd84aeb096 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -428,6 +428,7 @@ def test_reduction_no_spatial():
     class Before:
         @T.prim_func
         def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), 
"float16"), rms_norm: T.Buffer((1, 4096), "float16")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
             Ared_temp = T.alloc_buffer((1, 1))
             for ax0 in range(4096):
                 with T.block("Ared_temp"):
@@ -444,9 +445,9 @@ def test_reduction_no_spatial():
     class After:
         @T.prim_func
         def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), 
"float16"), rms_norm: T.Buffer((1, 4096), "float16")):
-            T.func_attr({"tir.is_scheduled": 1})
+            T.func_attr({"global_symbol": "main", "tir.noalias": True, 
"tir.is_scheduled": 1})
             # with T.block("root"):
-            Ared_temp_local = T.alloc_buffer((1, 1), scope="local")
+            Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared")
             Ared_temp_rf_local = T.alloc_buffer((256, 1, 1), scope="local")
             for ax0_fused in T.thread_binding(T.int64(1), 
thread="blockIdx.x"): # pylint: disable=unused-variable
                 for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
@@ -470,16 +471,17 @@ def test_reduction_no_spatial():
                             vax1_fused_1 = T.axis.reduce(256, ax0)
                             v0 = T.axis.spatial(T.int64(1), T.int64(0))
                             T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0])
-                            T.writes(Ared_temp_local[0, 0])
+                            T.writes(Ared_temp_shared[0, 0])
                             with T.init():
-                                Ared_temp_local[0, 0] = T.float32(0)
-                            Ared_temp_local[0, 0] = Ared_temp_local[0, 0] + 
Ared_temp_rf_local[vax1_fused_1, 0, 0]
-                for ax0 in range(4096):
-                    with T.block("rms_norm"):
-                        v0 = T.axis.spatial(4096, ax0)
-                        T.reads(B[v0], A[0, 0, v0], Ared_temp_local[0, 0])
-                        T.writes(rms_norm[0, v0])
-                        rms_norm[0, v0] = T.Cast("float16", T.Cast("float32", 
B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp_local[0, 0] * 
T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
+                                Ared_temp_shared[0, 0] = T.float32(0)
+                            Ared_temp_shared[0, 0] = Ared_temp_shared[0, 0] + 
Ared_temp_rf_local[vax1_fused_1, 0, 0]
+                for ax0_fused_0 in range(16):
+                    for ax0_fused_1 in T.thread_binding(256, 
thread="threadIdx.x"):
+                        with T.block("rms_norm"):
+                            v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + 
ax0_fused_1)
+                            T.reads(B[v0], A[0, 0, v0], Ared_temp_shared[0, 0])
+                            T.writes(rms_norm[0, v0])
+                            rms_norm[0, v0] = T.Cast("float16", 
T.Cast("float32", B[v0]) * (T.Cast("float32", A[0, 0, v0]) / 
T.sqrt(Ared_temp_shared[0, 0] * T.float32(0.000244140625) + 
T.float32(9.9999999999999995e-07))))
     # fmt: on
     target = Target("nvidia/geforce-rtx-3090-ti")
     with target:

Reply via email to