This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d1871a6957 [MetaSchedule] Relax conditions of rule Cross-Thread Reduction (#12825)
d1871a6957 is described below
commit d1871a6957b4f469f1b994aa6c89e0d209b64f05
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Sep 17 22:03:17 2022 -0400
[MetaSchedule] Relax conditions of rule Cross-Thread Reduction (#12825)
This PR relaxes the conditions of Meta-Schedule schedule rule
CrossThreadReduction. The rules are previously a bit over-strict, and some
workloads with small reduction loop length are unable to be optimized by
cross-thread reduction automatically. In this PR, we relax the rules so that
such workloads can be optimized.
---
src/tir/schedule/analysis/analysis.cc | 6 +-
...chedule_schedule_rule_cross_thread_reduction.py | 98 ++++++++++++++++++++++
2 files changed, 100 insertions(+), 4 deletions(-)
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4f78b0c9cd..e39f7b2554 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1640,11 +1640,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
if (NeedsMultiLevelTiling(self, block_sref)) {
// Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops.
return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent);
- } else if (cum_reduce_len > 1) {
- // Always try rfactor/cross-thread-reduction for other reduction blocks.
- return cum_reduce_len > max_parallel_basic;
} else {
- return false;
+ // Always try rfactor/cross-thread-reduction for other reduction blocks.
+ return cum_reduce_len > 1;
}
}
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index 4278638a1a..718b264bdd 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -589,6 +589,28 @@ def argmax(
argmax_v1[i] = v_argmax_v1
+@T.prim_func
+def argmax_32(
+ idx: T.Buffer[(1, 32), "int32"],
+ val: T.Buffer[(1, 32), "float32"],
+ argmax_v0: T.Buffer[(1,), "int32"],
+ argmax_v1: T.Buffer[(1,), "float32"],
+) -> None:
+ for i0, i1 in T.grid(1, 32):
+ with T.block("argmax"):
+ i = T.axis.spatial(1, i0)
+ k = T.axis.reduce(32, i1)
+ T.reads(idx[i, k], val[i, k])
+ T.writes(argmax_v0[i], argmax_v1[i])
+ with T.init():
+ argmax_v0[i] = -1
+ argmax_v1[i] = T.min_value("float32")
+ v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k])
+ v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k])
+ argmax_v0[i] = v_argmax_v0
+ argmax_v1[i] = v_argmax_v1
+
+
def test_gpu_argmax():
@T.prim_func
def argmax_0(
@@ -663,8 +685,84 @@ def test_gpu_argmax():
)
+def test_gpu_argmax_32():
+ @T.prim_func
+ def argmax_0(
+ idx: T.Buffer[(1, 32), "int32"],
+ val: T.Buffer[(1, 32), "float32"],
+ argmax_v0: T.Buffer[(1,), "int32"],
+ argmax_v1: T.Buffer[(1,), "float32"],
+ ) -> None:
+ # body
+ # with T.block("root")
+ for i0, i1 in T.grid(1, 32):
+ with T.block("argmax"):
+ i, k = T.axis.remap("SR", [i0, i1])
+ T.reads(idx[i, k], val[i, k])
+ T.writes(argmax_v0[i], argmax_v1[i])
+ with T.init():
+ argmax_v0[i] = -1
+ argmax_v1[i] = T.float32(-3.4028234663852886e38)
+ v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k])
+ v_argmax_v1: T.float32 = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
+ )
+ argmax_v0[i] = v_argmax_v0
+ argmax_v1[i] = v_argmax_v1
+
+ @T.prim_func
+ def argmax_1(
+ idx: T.Buffer[(1, 32), "int32"],
+ val: T.Buffer[(1, 32), "float32"],
+ argmax_v0: T.Buffer[(1,), "int32"],
+ argmax_v1: T.Buffer[(1,), "float32"],
+ ) -> None:
+ # body
+ # with T.block("root")
+ for i0, i1_0 in T.grid(1, 1):
+ for i1_1 in T.thread_binding(64, thread="threadIdx.x"):
+ with T.block("argmax"):
+ i = T.axis.spatial(1, i0)
+ k = T.axis.reduce(32, i1_0 * 64 + i1_1)
+ T.where(i1_0 * 64 + i1_1 < 32)
+ T.reads(idx[i, k], val[i, k])
+ T.writes(argmax_v0[i], argmax_v1[i])
+ with T.init():
+ argmax_v0[i] = -1
+ argmax_v1[i] = T.float32(-3.4028234663852886e38)
+ v_argmax_v0: T.int32 = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+ )
+ v_argmax_v1: T.float32 = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
+ )
+ argmax_v0[i] = v_argmax_v0
+ argmax_v1[i] = v_argmax_v1
+
+ decision_0 = [] # type: ignore
+ decision_1 = [
+ ("SampleCategorical", 4),
+ ]
+
+ mod = argmax_32
+ actual = ms.TuneContext(
+ mod=mod,
+ target=Target("nvidia/geforce-rtx-3090", host="llvm"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[argmax_0, argmax_1],
+ expected_decisions=[decision_0, decision_1],
+ )
+
+
if __name__ == "__main__":
test_gpu_softmax_mn()
test_gpu_softmax_mn_after_inline()
test_gpu_batch_norm_bmn()
test_gpu_argmax()
+ test_gpu_argmax_32()