This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d1871a6957 [MetaSchedule] Relax conditions of rule Cross-Thread 
Reduction (#12825)
d1871a6957 is described below

commit d1871a6957b4f469f1b994aa6c89e0d209b64f05
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Sep 17 22:03:17 2022 -0400

    [MetaSchedule] Relax conditions of rule Cross-Thread Reduction (#12825)
    
    This PR relaxes the conditions of the Meta-Schedule schedule rule 
CrossThreadReduction. The rules were previously a bit overly strict, and some 
workloads with small reduction loop lengths were unable to be optimized by 
cross-thread reduction automatically. In this PR, we relax the rules so that 
such workloads can be optimized.
---
 src/tir/schedule/analysis/analysis.cc              |  6 +-
 ...chedule_schedule_rule_cross_thread_reduction.py | 98 ++++++++++++++++++++++
 2 files changed, 100 insertions(+), 4 deletions(-)

diff --git a/src/tir/schedule/analysis/analysis.cc 
b/src/tir/schedule/analysis/analysis.cc
index 4f78b0c9cd..e39f7b2554 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1640,11 +1640,9 @@ bool NeedsRFactorOrCrossThreadReduction(const 
tir::ScheduleState& self,   //
   if (NeedsMultiLevelTiling(self, block_sref)) {
     // Do not use rfactor/cross-thread-reduction if we have enough parallelism 
on spatial loops.
     return !(cum_space_len >= cum_reduce_len || cum_space_len > 
max_parallel_extent);
-  } else if (cum_reduce_len > 1) {
-    // Always try rfactor/cross-thread-reduction for other reduction blocks.
-    return cum_reduce_len > max_parallel_basic;
   } else {
-    return false;
+    // Always try rfactor/cross-thread-reduction for other reduction blocks.
+    return cum_reduce_len > 1;
   }
 }
 
diff --git 
a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
 
b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index 4278638a1a..718b264bdd 100644
--- 
a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ 
b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -589,6 +589,28 @@ def argmax(
             argmax_v1[i] = v_argmax_v1
 
 
[email protected]_func
+def argmax_32(
+    idx: T.Buffer[(1, 32), "int32"],
+    val: T.Buffer[(1, 32), "float32"],
+    argmax_v0: T.Buffer[(1,), "int32"],
+    argmax_v1: T.Buffer[(1,), "float32"],
+) -> None:
+    for i0, i1 in T.grid(1, 32):
+        with T.block("argmax"):
+            i = T.axis.spatial(1, i0)
+            k = T.axis.reduce(32, i1)
+            T.reads(idx[i, k], val[i, k])
+            T.writes(argmax_v0[i], argmax_v1[i])
+            with T.init():
+                argmax_v0[i] = -1
+                argmax_v1[i] = T.min_value("float32")
+            v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
+            v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v1[i], val[i, k])
+            argmax_v0[i] = v_argmax_v0
+            argmax_v1[i] = v_argmax_v1
+
+
 def test_gpu_argmax():
     @T.prim_func
     def argmax_0(
@@ -663,8 +685,84 @@ def test_gpu_argmax():
     )
 
 
+def test_gpu_argmax_32():
+    @T.prim_func
+    def argmax_0(
+        idx: T.Buffer[(1, 32), "int32"],
+        val: T.Buffer[(1, 32), "float32"],
+        argmax_v0: T.Buffer[(1,), "int32"],
+        argmax_v1: T.Buffer[(1,), "float32"],
+    ) -> None:
+        # body
+        # with T.block("root")
+        for i0, i1 in T.grid(1, 32):
+            with T.block("argmax"):
+                i, k = T.axis.remap("SR", [i0, i1])
+                T.reads(idx[i, k], val[i, k])
+                T.writes(argmax_v0[i], argmax_v1[i])
+                with T.init():
+                    argmax_v0[i] = -1
+                    argmax_v1[i] = T.float32(-3.4028234663852886e38)
+                v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
+                v_argmax_v1: T.float32 = T.Select(
+                    argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
+                )
+                argmax_v0[i] = v_argmax_v0
+                argmax_v1[i] = v_argmax_v1
+
+    @T.prim_func
+    def argmax_1(
+        idx: T.Buffer[(1, 32), "int32"],
+        val: T.Buffer[(1, 32), "float32"],
+        argmax_v0: T.Buffer[(1,), "int32"],
+        argmax_v1: T.Buffer[(1,), "float32"],
+    ) -> None:
+        # body
+        # with T.block("root")
+        for i0, i1_0 in T.grid(1, 1):
+            for i1_1 in T.thread_binding(64, thread="threadIdx.x"):
+                with T.block("argmax"):
+                    i = T.axis.spatial(1, i0)
+                    k = T.axis.reduce(32, i1_0 * 64 + i1_1)
+                    T.where(i1_0 * 64 + i1_1 < 32)
+                    T.reads(idx[i, k], val[i, k])
+                    T.writes(argmax_v0[i], argmax_v1[i])
+                    with T.init():
+                        argmax_v0[i] = -1
+                        argmax_v1[i] = T.float32(-3.4028234663852886e38)
+                    v_argmax_v0: T.int32 = T.Select(
+                        argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+                    )
+                    v_argmax_v1: T.float32 = T.Select(
+                        argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
+                    )
+                    argmax_v0[i] = v_argmax_v0
+                    argmax_v1[i] = v_argmax_v1
+
+    decision_0 = []  # type: ignore
+    decision_1 = [
+        ("SampleCategorical", 4),
+    ]
+
+    mod = argmax_32
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[argmax_0, argmax_1],
+        expected_decisions=[decision_0, decision_1],
+    )
+
+
 if __name__ == "__main__":
     test_gpu_softmax_mn()
     test_gpu_softmax_mn_after_inline()
     test_gpu_batch_norm_bmn()
     test_gpu_argmax()
+    test_gpu_argmax_32()

Reply via email to the mailing list.