This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0d7e2ec [TIR] For-kind inheritance in decompose-reduction (#9814)
0d7e2ec is described below
commit 0d7e2ec129e746d19705ec6eb32e939543afff51
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Dec 31 11:34:20 2021 +0800
[TIR] For-kind inheritance in decompose-reduction (#9814)
---
src/tir/schedule/primitive/reduction.cc | 2 +-
.../python/unittest/test_tir_schedule_reduction.py | 39 ++++++++++++++++++++++
2 files changed, 40 insertions(+), 1 deletion(-)
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 096e616..72ea199 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -281,7 +281,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
body = For(/*loop_var=*/new_loop_var,
/*min=*/old_loop->min,
/*extent=*/old_loop->extent,
- /*kind=*/ForKind::kSerial,
+ /*kind=*/old_loop->kind,
/*body=*/body);
}
body = Substitute(body, loop_var_map);
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py
index 5f5daa1..5ad366b 100644
--- a/tests/python/unittest/test_tir_schedule_reduction.py
+++ b/tests/python/unittest/test_tir_schedule_reduction.py
@@ -185,6 +185,34 @@ def matmul_decompose_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> N
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+@T.prim_func
+def colsum_with_vectorization(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 32], dtype="float32")
+ B = T.match_buffer(b, [32], dtype="float32")
+ for k in T.serial(0, 128):
+ for i in T.vectorized(0, 32):
+ with T.block("B"):
+ vk, vi = T.axis.remap("RS", [k, i])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vk, vi]
+
+
+@T.prim_func
+def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 32], dtype="float32")
+ B = T.match_buffer(b, [32], dtype="float32")
+ for i in T.vectorized(0, 32):
+ with T.block("B_init"):
+ vi = T.axis.S(32, i)
+ B[vi] = T.float32(0)
+ for k in T.serial(0, 128):
+ for i in T.vectorized(0, 32):
+ with T.block("B"):
+ vk, vi = T.axis.remap("RS", [k, i])
+ B[vi] = B[vi] + A[vk, vi]
+
+
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
@@ -243,5 +271,16 @@ def test_reduction_decompose_with_annotation():
verify_trace_roundtrip(s, mod=matmul_with_annotation)
+def test_reduction_decompose_with_different_for_kind():
+ s = tir.Schedule(colsum_with_vectorization, debug_mask="all")
+ B = s.get_block("B")
+ k, _ = s.get_loops(B)
+ B_init = s.decompose_reduction(B, k)
+ tvm.ir.assert_structural_equal(s.mod["main"], colsum_decompose_with_vectorization)
+ assert s.get(B).same_as(s.get(s.get_block("B_update")))
+ assert s.get(B_init).same_as(s.get(s.get_block("B_init")))
+ verify_trace_roundtrip(s, mod=colsum_with_vectorization)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))