This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 608553462b [TIR][Schedule] Fix decompose reduction with thread binding loops (#15465)
608553462b is described below
commit 608553462b10844bd65a50a5803b40791cbf5c72
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Aug 4 00:54:59 2023 -0700
[TIR][Schedule] Fix decompose reduction with thread binding loops (#15465)
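This PR fixes `decompose_reduction` when the loops above the reduction block include thread-binding loops. Previously, the generated `_init` loop copied the original loop's kind but not its thread binding. The init loop now carries a copy of the original `IterVar` binding with a fresh thread variable, so the init block is bound to the same thread axis as the original loop.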
---
src/tir/schedule/primitive/reduction.cc | 10 ++++++-
.../python/unittest/test_tir_schedule_reduction.py | 33 ++++++++++++++++++++++
2 files changed, 42 insertions(+), 1 deletion(-)
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index cade5457b0..e1c90cc645 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -271,11 +271,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
Var old_loop_var = old_loop->loop_var;
Var new_loop_var = old_loop_var.copy_with_suffix("_init");
loop_var_map[old_loop_var] = new_loop_var;
+ Optional<IterVar> opt_thread_binding = old_loop->thread_binding;
+ if (opt_thread_binding) {
+ auto thread_binding = opt_thread_binding.value();
+ auto new_var = thread_binding->var.copy_with_suffix("");
+ thread_binding.CopyOnWrite()->var = new_var;
+ opt_thread_binding = thread_binding;
+ }
body = For(/*loop_var=*/new_loop_var,
/*min=*/old_loop->min,
/*extent=*/old_loop->extent,
/*kind=*/old_loop->kind,
- /*body=*/body);
+ /*body=*/body,
+ /*thread_binding=*/opt_thread_binding);
}
body = Substitute(body, loop_var_map);
// Step 6. Mutate IR
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py
index a1e5ed74c2..4ed3c6178f 100644
--- a/tests/python/unittest/test_tir_schedule_reduction.py
+++ b/tests/python/unittest/test_tir_schedule_reduction.py
@@ -353,5 +353,38 @@ def test_decompose_reduction_nested_block():
verify_trace_roundtrip(sch, mod=nested_block)
+class TestDecomposeReductionWithThreadBinding(tvm.testing.CompareBeforeAfter):
+ def transform(self):
+ def func(mod):
+ sch = tir.Schedule(mod)
+ t, _ = sch.get_loops("B")
+ sch.decompose_reduction("B", t)
+ return sch.mod
+
+ return func
+
+ @T.prim_func
+ def before(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,),
"float32")):
+ for t in T.thread_binding(0, 32, thread="threadIdx.x"):
+ for r in T.serial(16):
+ with T.block("B"):
+ vi, vr = T.axis.remap("SR", [t, r])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] += A[vi, vr]
+
+ @T.prim_func
+ def expected(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,),
"float32")):
+ for t_init in T.thread_binding(0, 32, thread="threadIdx.x"):
+ with T.block("B_init"):
+ vi = T.axis.remap("S", [t_init])
+ B[vi] = T.float32(0)
+ for t in T.thread_binding(0, 32, thread="threadIdx.x"):
+ for r in T.serial(16):
+ with T.block("B"):
+ vi, vr = T.axis.remap("SR", [t, r])
+ B[vi] += A[vi, vr]
+
+
if __name__ == "__main__":
tvm.testing.main()