This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d7e2ec  [TIR] For-kind inheritance in decompose-reduction (#9814)
0d7e2ec is described below

commit 0d7e2ec129e746d19705ec6eb32e939543afff51
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Dec 31 11:34:20 2021 +0800

    [TIR] For-kind inheritance in decompose-reduction (#9814)
---
 src/tir/schedule/primitive/reduction.cc            |  2 +-
 .../python/unittest/test_tir_schedule_reduction.py | 39 ++++++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 096e616..72ea199 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -281,7 +281,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
     body = For(/*loop_var=*/new_loop_var,
                /*min=*/old_loop->min,
                /*extent=*/old_loop->extent,
-               /*kind=*/ForKind::kSerial,
+               /*kind=*/old_loop->kind,
                /*body=*/body);
   }
   body = Substitute(body, loop_var_map);
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py
index 5f5daa1..5ad366b 100644
--- a/tests/python/unittest/test_tir_schedule_reduction.py
+++ b/tests/python/unittest/test_tir_schedule_reduction.py
@@ -185,6 +185,34 @@ def matmul_decompose_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> N
             C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
 
 
+@T.prim_func
+def colsum_with_vectorization(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 32], dtype="float32")
+    B = T.match_buffer(b, [32], dtype="float32")
+    for k in T.serial(0, 128):
+        for i in T.vectorized(0, 32):
+            with T.block("B"):
+                vk, vi = T.axis.remap("RS", [k, i])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vk, vi]
+
+
+@T.prim_func
+def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 32], dtype="float32")
+    B = T.match_buffer(b, [32], dtype="float32")
+    for i in T.vectorized(0, 32):
+        with T.block("B_init"):
+            vi = T.axis.S(32, i)
+            B[vi] = T.float32(0)
+    for k in T.serial(0, 128):
+        for i in T.vectorized(0, 32):
+            with T.block("B"):
+                vk, vi = T.axis.remap("RS", [k, i])
+                B[vi] = B[vi] + A[vk, vi]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
@@ -243,5 +271,16 @@ def test_reduction_decompose_with_annotation():
     verify_trace_roundtrip(s, mod=matmul_with_annotation)
 
 
+def test_reduction_decompose_with_different_for_kind():
+    s = tir.Schedule(colsum_with_vectorization, debug_mask="all")
+    B = s.get_block("B")
+    k, _ = s.get_loops(B)
+    B_init = s.decompose_reduction(B, k)
+    tvm.ir.assert_structural_equal(s.mod["main"], colsum_decompose_with_vectorization)
+    assert s.get(B).same_as(s.get(s.get_block("B_update")))
+    assert s.get(B_init).same_as(s.get(s.get_block("B_init")))
+    verify_trace_roundtrip(s, mod=colsum_with_vectorization)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to