This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 608553462b [TIR][Schedule] Fix decompose reduction with thread binding loops (#15465)
608553462b is described below

commit 608553462b10844bd65a50a5803b40791cbf5c72
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Aug 4 00:54:59 2023 -0700

    [TIR][Schedule] Fix decompose reduction with thread binding loops (#15465)
    
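    This PR fixes decompose reduction with thread binding loops: the `_init` loops generated by `decompose_reduction` now keep the original loop's thread binding (with a freshly copied IterVar variable) instead of dropping it.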
---
 src/tir/schedule/primitive/reduction.cc            | 10 ++++++-
 .../python/unittest/test_tir_schedule_reduction.py | 33 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index cade5457b0..e1c90cc645 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -271,11 +271,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
     Var old_loop_var = old_loop->loop_var;
     Var new_loop_var = old_loop_var.copy_with_suffix("_init");
     loop_var_map[old_loop_var] = new_loop_var;
+    Optional<IterVar> opt_thread_binding = old_loop->thread_binding;
+    if (opt_thread_binding) {
+      auto thread_binding = opt_thread_binding.value();
+      auto new_var = thread_binding->var.copy_with_suffix("");
+      thread_binding.CopyOnWrite()->var = new_var;
+      opt_thread_binding = thread_binding;
+    }
     body = For(/*loop_var=*/new_loop_var,
                /*min=*/old_loop->min,
                /*extent=*/old_loop->extent,
                /*kind=*/old_loop->kind,
-               /*body=*/body);
+               /*body=*/body,
+               /*thread_binding=*/opt_thread_binding);
   }
   body = Substitute(body, loop_var_map);
   // Step 6. Mutate IR
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py
index a1e5ed74c2..4ed3c6178f 100644
--- a/tests/python/unittest/test_tir_schedule_reduction.py
+++ b/tests/python/unittest/test_tir_schedule_reduction.py
@@ -353,5 +353,38 @@ def test_decompose_reduction_nested_block():
     verify_trace_roundtrip(sch, mod=nested_block)
 
 
+class TestDecomposeReductionWithThreadBinding(tvm.testing.CompareBeforeAfter):
+    def transform(self):
+        def func(mod):
+            sch = tir.Schedule(mod)
+            t, _ = sch.get_loops("B")
+            sch.decompose_reduction("B", t)
+            return sch.mod
+
+        return func
+
+    @T.prim_func
+    def before(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")):
+        for t in T.thread_binding(0, 32, thread="threadIdx.x"):
+            for r in T.serial(16):
+                with T.block("B"):
+                    vi, vr = T.axis.remap("SR", [t, r])
+                    with T.init():
+                        B[vi] = T.float32(0)
+                    B[vi] += A[vi, vr]
+
+    @T.prim_func
+    def expected(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")):
+        for t_init in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("B_init"):
+                vi = T.axis.remap("S", [t_init])
+                B[vi] = T.float32(0)
+        for t in T.thread_binding(0, 32, thread="threadIdx.x"):
+            for r in T.serial(16):
+                with T.block("B"):
+                    vi, vr = T.axis.remap("SR", [t, r])
+                    B[vi] += A[vi, vr]
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to