This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fa5460242e [MetaSchedule] Enhance CPU auto vectorization (#11404)
fa5460242e is described below

commit fa5460242e31cea3df7db8efe42da57196eba25e
Author: Junru Shao <[email protected]>
AuthorDate: Sat May 21 07:21:15 2022 -0700

    [MetaSchedule] Enhance CPU auto vectorization (#11404)
---
 .../postproc/rewrite_parallel_vectorize_unroll.cc  |  2 +-
 ...e_postproc_rewrite_parallel_vectorize_unroll.py | 91 +++++++++++++++++++++-
 2 files changed, 89 insertions(+), 4 deletions(-)

diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 69e8dfb858..001c97645b 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -207,7 +207,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
         continue;
       } else if (prev_used_iter == -1) {
         // the stride of last axis is not 1 means the memory access is not contiguous
-        if (strides[i] != 1) {
+        if (strides[i] != 1 && fusible != 0) {
           break;
         }
         fusible++;
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index 9988e874b8..f9b71bfdb6 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -16,9 +16,8 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import tvm
-from tvm.script import tir as T
-
 from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll
+from tvm.script import tir as T
 from tvm.tir.schedule import Schedule
 
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
@@ -70,6 +69,85 @@ def Move_PUV0(a: T.handle, b: T.handle) -> None:
                         T.writes([B[vi, vj, vk]])
                         B[vi, vj, vk] = A[vi, vj, vk]
 
+
+@tvm.script.ir_module
+class Fused_NN_Dense:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(64, 768), "float32"], placeholder_1: T.Buffer[(768, 768), "float32"], T_matmul_NT: T.Buffer[(64, 768), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
+        # body
+        # with T.block("root")
+        for i0, i1, i2 in T.grid(64, 768, 768):
+            with T.block("T_matmul_NT"):
+                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                T.reads(placeholder[i, k], placeholder_1[j, k])
+                T.writes(T_matmul_NT[i, j])
+                with T.init():
+                    T_matmul_NT[i, j] = T.float32(0)
+                T_matmul_NT[i, j] = T_matmul_NT[i, j] + placeholder[i, k] * placeholder_1[j, k]
+
+@T.prim_func
+def before_matmul_vectorize(
+    placeholder: T.Buffer[(64, 768), "float32"],
+    placeholder_1: T.Buffer[(768, 768), "float32"],
+    T_matmul_NT: T.Buffer[(64, 768), "float32"],
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
+    with T.block("root"):
+        T.reads()
+        T.writes()
+        T.block_attr({"meta_schedule.vectorize":64})
+        T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
+        for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
+            for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(48, 8, 1, 16, 8, 16):
+                with T.block("T_matmul_NT"):
+                    i = T.axis.spatial(64, i0_2 * 8 + i0_3)
+                    j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3)
+                    k = T.axis.reduce(768, i2_0 * 16 + i2_1)
+                    T.reads(placeholder[i, k], placeholder_1[j, k])
+                    T.writes(T_matmul_NT_global[i, j])
+                    with T.init():
+                        T_matmul_NT_global[i, j] = T.float32(0)
+                    T_matmul_NT_global[i, j] = T_matmul_NT_global[i, j] + placeholder[i, k] * placeholder_1[j, k]
+            for ax0, ax1 in T.grid(64, 16):
+                with T.block("T_matmul_NT_global"):
+                    v0 = T.axis.spatial(64, ax0)
+                    v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1)
+                    T.reads(T_matmul_NT_global[v0, v1])
+                    T.writes(T_matmul_NT[v0, v1])
+                    T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
+
+@T.prim_func
+def after_matmul_vectorize(
+    placeholder: T.Buffer[(64, 768), "float32"],
+    placeholder_1: T.Buffer[(768, 768), "float32"],
+    T_matmul_NT: T.Buffer[(64, 768), "float32"],
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
+    T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
+    for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
+        for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8):
+            for i1_3_fused in T.vectorized(16):
+                with T.block("T_matmul_NT"):
+                    i = T.axis.spatial(64, i0_2 * 8 + i0_3)
+                    j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3_fused)
+                    k = T.axis.reduce(768, i2_0 * 16 + i2_1)
+                    T.reads(placeholder[i, k], placeholder_1[j, k])
+                    T.writes(T_matmul_NT_global[i, j])
+                    with T.init():
+                        T_matmul_NT_global[i, j] = T.float32(0)
+                    T_matmul_NT_global[i, j] = T_matmul_NT_global[i, j] + placeholder[i, k] * placeholder_1[j, k]
+        for ax0 in T.serial(64):
+            for ax1_fused in T.vectorized(16):
+                with T.block("T_matmul_NT_global"):
+                    v0 = T.axis.spatial(64, ax0)
+                    v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1_fused)
+                    T.reads(T_matmul_NT_global[v0, v1])
+                    T.writes(T_matmul_NT[v0, v1])
+                    T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
+
+
 # fmt: on
 # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable
 
@@ -78,10 +156,17 @@ def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize():
     postproc = RewriteParallelVectorizeUnroll()
     sch = Schedule(Move_PUV)
     assert postproc.apply(sch)
-    print(sch.mod["main"].script())
     mod = tvm.tir.transform.Simplify()(sch.mod)
     tvm.ir.assert_structural_equal(mod["main"], Move_PUV0)
 
 
+def test_vectorize_inner_loop():
+    sch = Schedule(before_matmul_vectorize)
+    rule = RewriteParallelVectorizeUnroll()
+    assert rule.apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize)
+
+
 if __name__ == "__main__":
     test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize()
+    test_vectorize_inner_loop()

Reply via email to