This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fe24fa9840 [Bugfix][MetaSchedule] Auto-bind when there are no spatial loops (#11570)
fe24fa9840 is described below
commit fe24fa9840500b9217f5773e65a764a16e998a66
Author: Junru Shao <[email protected]>
AuthorDate: Sat Jun 4 01:37:23 2022 -0700
[Bugfix][MetaSchedule] Auto-bind when there are no spatial loops (#11570)
---
src/meta_schedule/schedule_rule/auto_bind.cc | 38 +++++++++++++-----
.../test_meta_schedule_schedule_rule_auto_bind.py | 45 +++++++++++++++++++++-
2 files changed, 72 insertions(+), 11 deletions(-)
diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc
index 9c16856557..61f8e4f6fc 100644
--- a/src/meta_schedule/schedule_rule/auto_bind.cc
+++ b/src/meta_schedule/schedule_rule/auto_bind.cc
@@ -72,7 +72,7 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv,
if (i_multi_child == -1) {
i_multi_child = n;
}
- if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
+ if (i_block_idx != -1 && i_thread_idx != -1) {
return;
}
if (i_block_idx != -1 && i_thread_idx == -1) {
@@ -80,16 +80,34 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv,
throw;
}
LoopRV loop_rv{nullptr};
- if (i_block_idx == -1 && i_thread_idx != -1) {
- int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
+ {
Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
- loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
- sch->Bind(loop_rv, "blockIdx.x");
- return;
- } else { // i_block_idx == -1 && i_thread_idx == -1
- Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
- int num_fuse = std::min(i_multi_child, i_spatial_loop + 1);
- loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
+ if (i_spatial_loop == -1) {
+ Array<LoopRV> split = sch->Split(loop_rvs[0], {Integer(1), NullOpt});
+ ICHECK_EQ(split.size(), 2);
+ loop_rvs.Set(0, split[1]);
+ loop_rvs.insert(loop_rvs.begin(), split[0]);
+ i_spatial_loop = 0;
+ if (i_block_idx != -1) {
+ i_block_idx += 1;
+ }
+ if (i_thread_idx != -1) {
+ i_thread_idx += 1;
+ }
+ if (i_multi_child != -1) {
+ i_multi_child += 1;
+ }
+ }
+ if (i_block_idx == -1 && i_thread_idx != -1) {
+ int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
+ Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
+ loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
+ sch->Bind(loop_rv, "blockIdx.x");
+ return;
+ } else { // i_block_idx == -1 && i_thread_idx == -1
+ int num_fuse = std::min(i_multi_child, i_spatial_loop + 1);
+ loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
+ }
}
int64_t extent = -1;
if (const int64_t* e = GetLoopIntExtent(sch->Get(loop_rv).get())) {
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
index bd0a24e8b6..80a72a4e93 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
@@ -20,8 +20,8 @@ from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.meta_schedule.testing.schedule_rule import auto_bind
from tvm.meta_schedule.testing.space_generation import check_trace
from tvm.meta_schedule.tune_context import TuneContext
-from tvm.target import Target
from tvm.script import tir as T
+from tvm.target import Target
@T.prim_func
@@ -34,6 +34,25 @@ def element_wise(var_A: T.handle, var_B: T.handle) -> None:
B[vi, vj] = A[vi, vj] + 1.0
+@T.prim_func
+def reduction_loop_only(
+ A: T.Buffer[2, "float32"],
+ B: T.Buffer[2, "float32"],
+ C: T.Buffer[(), "float32"],
+) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ for i0 in T.serial(2):
+ with T.block("C"):
+ k0 = T.axis.reduce(2, i0)
+ T.reads(A[k0], B[k0])
+ T.writes(C[()])
+ with T.init():
+ C[()] = T.float32(1.0)
+ C[()] = T.min(C[()], A[k0] / B[k0])
+
+
def _create_context(mod, target, rule) -> TuneContext:
ctx = TuneContext(
mod=mod,
@@ -71,5 +90,29 @@ def test_cuda_element_wise():
check_trace(spaces, expected)
+def test_cuda_reduction_loop_only():
+ expected = [
+ [
+ 'b0 = sch.get_block(name="C", func_name="main")',
+ "l1, = sch.get_loops(block=b0)",
+ "l2, l3 = sch.split(loop=l1, factors=[1, None])",
+ "l4 = sch.fuse(l2)",
+ "l5, l6 = sch.split(loop=l4, factors=[None, 1])",
+ 'sch.bind(loop=l5, thread_axis="blockIdx.x")',
+ 'sch.bind(loop=l6, thread_axis="threadIdx.x")',
+ ]
+ ]
+ target = Target("nvidia/geforce-rtx-3080", host="llvm")
+ ctx = _create_context(
+ reduction_loop_only,
+ target=target,
+ rule=auto_bind(target=target),
+ )
+ spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+ assert len(spaces) == 1
+ check_trace(spaces, expected)
+
+
if __name__ == "__main__":
test_cuda_element_wise()
+ test_cuda_reduction_loop_only()