This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fe24fa9840 [Bugfix][MetaSchedule] Auto-bind when there are no spatial loops (#11570)
fe24fa9840 is described below

commit fe24fa9840500b9217f5773e65a764a16e998a66
Author: Junru Shao <[email protected]>
AuthorDate: Sat Jun 4 01:37:23 2022 -0700

    [Bugfix][MetaSchedule] Auto-bind when there are no spatial loops (#11570)
---
 src/meta_schedule/schedule_rule/auto_bind.cc       | 38 +++++++++++++-----
 .../test_meta_schedule_schedule_rule_auto_bind.py  | 45 +++++++++++++++++++++-
 2 files changed, 72 insertions(+), 11 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc
index 9c16856557..61f8e4f6fc 100644
--- a/src/meta_schedule/schedule_rule/auto_bind.cc
+++ b/src/meta_schedule/schedule_rule/auto_bind.cc
@@ -72,7 +72,7 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv,
   if (i_multi_child == -1) {
     i_multi_child = n;
   }
-  if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
+  if (i_block_idx != -1 && i_thread_idx != -1) {
     return;
   }
   if (i_block_idx != -1 && i_thread_idx == -1) {
@@ -80,16 +80,34 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv,
     throw;
   }
   LoopRV loop_rv{nullptr};
-  if (i_block_idx == -1 && i_thread_idx != -1) {
-    int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
+  {
     Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
-    loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
-    sch->Bind(loop_rv, "blockIdx.x");
-    return;
-  } else {  // i_block_idx == -1 && i_thread_idx == -1
-    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
-    int num_fuse = std::min(i_multi_child, i_spatial_loop + 1);
-    loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
+    if (i_spatial_loop == -1) {
+      Array<LoopRV> split = sch->Split(loop_rvs[0], {Integer(1), NullOpt});
+      ICHECK_EQ(split.size(), 2);
+      loop_rvs.Set(0, split[1]);
+      loop_rvs.insert(loop_rvs.begin(), split[0]);
+      i_spatial_loop = 0;
+      if (i_block_idx != -1) {
+        i_block_idx += 1;
+      }
+      if (i_thread_idx != -1) {
+        i_thread_idx += 1;
+      }
+      if (i_multi_child != -1) {
+        i_multi_child += 1;
+      }
+    }
+    if (i_block_idx == -1 && i_thread_idx != -1) {
+      int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
+      Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
+      loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
+      sch->Bind(loop_rv, "blockIdx.x");
+      return;
+    } else {  // i_block_idx == -1 && i_thread_idx == -1
+      int num_fuse = std::min(i_multi_child, i_spatial_loop + 1);
+      loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse});
+    }
   }
   int64_t extent = -1;
   if (const int64_t* e = GetLoopIntExtent(sch->Get(loop_rv).get())) {
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
index bd0a24e8b6..80a72a4e93 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
@@ -20,8 +20,8 @@ from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
 from tvm.meta_schedule.testing.schedule_rule import auto_bind
 from tvm.meta_schedule.testing.space_generation import check_trace
 from tvm.meta_schedule.tune_context import TuneContext
-from tvm.target import Target
 from tvm.script import tir as T
+from tvm.target import Target
 
 
 @T.prim_func
@@ -34,6 +34,25 @@ def element_wise(var_A: T.handle, var_B: T.handle) -> None:
             B[vi, vj] = A[vi, vj] + 1.0
 
 
+@T.prim_func
+def reduction_loop_only(
+    A: T.Buffer[2, "float32"],
+    B: T.Buffer[2, "float32"],
+    C: T.Buffer[(), "float32"],
+) -> None:
+    # function attr dict
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    # body
+    for i0 in T.serial(2):
+        with T.block("C"):
+            k0 = T.axis.reduce(2, i0)
+            T.reads(A[k0], B[k0])
+            T.writes(C[()])
+            with T.init():
+                C[()] = T.float32(1.0)
+            C[()] = T.min(C[()], A[k0] / B[k0])
+
+
 def _create_context(mod, target, rule) -> TuneContext:
     ctx = TuneContext(
         mod=mod,
@@ -71,5 +90,29 @@ def test_cuda_element_wise():
     check_trace(spaces, expected)
 
 
+def test_cuda_reduction_loop_only():
+    expected = [
+        [
+            'b0 = sch.get_block(name="C", func_name="main")',
+            "l1, = sch.get_loops(block=b0)",
+            "l2, l3 = sch.split(loop=l1, factors=[1, None])",
+            "l4 = sch.fuse(l2)",
+            "l5, l6 = sch.split(loop=l4, factors=[None, 1])",
+            'sch.bind(loop=l5, thread_axis="blockIdx.x")',
+            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
+        ]
+    ]
+    target = Target("nvidia/geforce-rtx-3080", host="llvm")
+    ctx = _create_context(
+        reduction_loop_only,
+        target=target,
+        rule=auto_bind(target=target),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+    check_trace(spaces, expected)
+
+
 if __name__ == "__main__":
     test_cuda_element_wise()
+    test_cuda_reduction_loop_only()

Reply via email to