This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new fae8308813 [Dlight] Enhance fallback schedule with DecomposeReduction 
(#15302)
fae8308813 is described below

commit fae83088130cf22ce7511dbbfad392b24e83bed7
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Jul 12 11:23:57 2023 -0700

    [Dlight] Enhance fallback schedule with DecomposeReduction (#15302)
    
    The current fallback dlight schedule does not decompose the reduction
    init blocks, which might lead to correctness issues (observed on
    Metal but not on CUDA).
    
    Doing DecomposeReduction effectively resolves the issue and meanwhile
    provides a (minor) performance improvement.
---
 python/tvm/dlight/gpu/fallback.py        | 10 ++++++-
 tests/python/dlight/test_gpu_fallback.py | 47 +++++++++++++++++++++++++++++++-
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/python/tvm/dlight/gpu/fallback.py 
b/python/tvm/dlight/gpu/fallback.py
index 14b74887af..d209f88ec0 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=missing-docstring
 """A fallback schedule rule for GPU operators."""
-from typing import List
+from typing import List, Tuple
 
 from tvm import tir
 from tvm.target import Target
@@ -41,6 +41,7 @@ class Fallback(ScheduleRule):
 
         sch = tir.Schedule(func)
         block_infos = try_inline(sch, normalize_prim_func(sch))
+        reduction_blocks: List[Tuple[tir.schedule.BlockRV, 
tir.schedule.LoopRV]] = []
         for block in block_infos:
             s_loops: List[tir.schedule.LoopRV] = []
             r_loops: List[tir.schedule.LoopRV] = []
@@ -59,4 +60,11 @@ class Fallback(ScheduleRule):
             )
             sch.bind(bx, "blockIdx.x")
             sch.bind(tx, "threadIdx.x")
+
+            if len(r_loops) > 0:
+                reduction_blocks.append((block, r_loops[0]))
+
+        for block, r_loop in reduction_blocks:
+            sch.decompose_reduction(block, r_loop)
+
         return sch
diff --git a/tests/python/dlight/test_gpu_fallback.py 
b/tests/python/dlight/test_gpu_fallback.py
index 38e9a391dc..d3fce0ee99 100644
--- a/tests/python/dlight/test_gpu_fallback.py
+++ b/tests/python/dlight/test_gpu_fallback.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+import tvm.testing
 from tvm import dlight as dl
 from tvm.ir import assert_structural_equal
 from tvm.script import ir as I
@@ -64,5 +65,49 @@ def test_fallback():
     assert_structural_equal(mod, After)
 
 
+def test_fallback_reduction():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((1, 6144), "float32"), B: T.Buffer((1,), 
"float32")):
+            for ax0, ax1 in T.grid(1, 6144):
+                with T.block("block"):
+                    v0 = T.axis.spatial(1, ax0)
+                    v1 = T.axis.reduce(6144, ax1)
+                    T.reads(A[v0, v1])
+                    T.writes(B[v0])
+                    with T.init():
+                        B[v0] = T.float32(0)
+                    B[v0] = B[v0] + T.Cast("float32", A[v0, v1])
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer((1, 6144), "float32"), B: T.Buffer((1,), 
"float32")):
+            T.func_attr({"tir.is_scheduled": 1})
+            for ax0_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(T.int64(1024), 
thread="threadIdx.x"):
+                    with T.block("block_init"):
+                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                        T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < 
T.int64(1))
+                        T.reads()
+                        T.writes(B[0])
+                        B[0] = T.float32(0)
+                    for ax1 in range(6144):
+                        with T.block("block_update"):
+                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            v1 = T.axis.reduce(6144, ax1)
+                            T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 
< T.int64(1))
+                            T.reads(B[0], A[0, v1])
+                            T.writes(B[0])
+                            B[0] = B[0] + T.Cast("float32", A[0, v1])
+
+    with Target("apple/m1-gpu"):
+        mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
+            dl.gpu.Fallback(),
+        )(Module)
+    assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
-    test_fallback()
+    tvm.testing.main()

Reply via email to