This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fa5460242e [MetaSchedule] Enhance CPU auto vectorization (#11404)
fa5460242e is described below
commit fa5460242e31cea3df7db8efe42da57196eba25e
Author: Junru Shao <[email protected]>
AuthorDate: Sat May 21 07:21:15 2022 -0700
[MetaSchedule] Enhance CPU auto vectorization (#11404)
---
.../postproc/rewrite_parallel_vectorize_unroll.cc | 2 +-
...e_postproc_rewrite_parallel_vectorize_unroll.py | 91 +++++++++++++++++++++-
2 files changed, 89 insertions(+), 4 deletions(-)
diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 69e8dfb858..001c97645b 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -207,7 +207,7 @@ void AdjustParallelVectorize(const Schedule& sch, const
BlockRV& block_rv,
continue;
} else if (prev_used_iter == -1) {
// the stride of last axis is not 1 means the memory access is not
contiguous
- if (strides[i] != 1) {
+ if (strides[i] != 1 && fusible != 0) {
break;
}
fusible++;
diff --git
a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index 9988e874b8..f9b71bfdb6 100644
---
a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++
b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -16,9 +16,8 @@
# under the License.
# pylint:
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
-from tvm.script import tir as T
-
from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll
+from tvm.script import tir as T
from tvm.tir.schedule import Schedule
# pylint:
disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
@@ -70,6 +69,85 @@ def Move_PUV0(a: T.handle, b: T.handle) -> None:
T.writes([B[vi, vj, vk]])
B[vi, vj, vk] = A[vi, vj, vk]
+
+@tvm.script.ir_module
+class Fused_NN_Dense:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(64, 768), "float32"], placeholder_1:
T.Buffer[(768, 768), "float32"], T_matmul_NT: T.Buffer[(64, 768), "float32"])
-> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True,
"layout_free_placeholders": [1]})
+ # body
+ # with T.block("root")
+ for i0, i1, i2 in T.grid(64, 768, 768):
+ with T.block("T_matmul_NT"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(placeholder[i, k], placeholder_1[j, k])
+ T.writes(T_matmul_NT[i, j])
+ with T.init():
+ T_matmul_NT[i, j] = T.float32(0)
+ T_matmul_NT[i, j] = T_matmul_NT[i, j] + placeholder[i, k] *
placeholder_1[j, k]
+
+@T.prim_func
+def before_matmul_vectorize(
+ placeholder: T.Buffer[(64, 768), "float32"],
+ placeholder_1: T.Buffer[(768, 768), "float32"],
+ T_matmul_NT: T.Buffer[(64, 768), "float32"],
+) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True,
"layout_free_placeholders": [1]})
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ T.block_attr({"meta_schedule.vectorize":64})
+ T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
+ for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
+ for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(48, 8, 1, 16, 8,
16):
+ with T.block("T_matmul_NT"):
+ i = T.axis.spatial(64, i0_2 * 8 + i0_3)
+ j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3)
+ k = T.axis.reduce(768, i2_0 * 16 + i2_1)
+ T.reads(placeholder[i, k], placeholder_1[j, k])
+ T.writes(T_matmul_NT_global[i, j])
+ with T.init():
+ T_matmul_NT_global[i, j] = T.float32(0)
+ T_matmul_NT_global[i, j] = T_matmul_NT_global[i, j] +
placeholder[i, k] * placeholder_1[j, k]
+ for ax0, ax1 in T.grid(64, 16):
+ with T.block("T_matmul_NT_global"):
+ v0 = T.axis.spatial(64, ax0)
+ v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1)
+ T.reads(T_matmul_NT_global[v0, v1])
+ T.writes(T_matmul_NT[v0, v1])
+ T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
+
+@T.prim_func
+def after_matmul_vectorize(
+ placeholder: T.Buffer[(64, 768), "float32"],
+ placeholder_1: T.Buffer[(768, 768), "float32"],
+ T_matmul_NT: T.Buffer[(64, 768), "float32"],
+) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True,
"layout_free_placeholders": [1]})
+ T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
+ for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
+ for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8):
+ for i1_3_fused in T.vectorized(16):
+ with T.block("T_matmul_NT"):
+ i = T.axis.spatial(64, i0_2 * 8 + i0_3)
+ j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3_fused)
+ k = T.axis.reduce(768, i2_0 * 16 + i2_1)
+ T.reads(placeholder[i, k], placeholder_1[j, k])
+ T.writes(T_matmul_NT_global[i, j])
+ with T.init():
+ T_matmul_NT_global[i, j] = T.float32(0)
+ T_matmul_NT_global[i, j] = T_matmul_NT_global[i, j] +
placeholder[i, k] * placeholder_1[j, k]
+ for ax0 in T.serial(64):
+ for ax1_fused in T.vectorized(16):
+ with T.block("T_matmul_NT_global"):
+ v0 = T.axis.spatial(64, ax0)
+ v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1_fused)
+ T.reads(T_matmul_NT_global[v0, v1])
+ T.writes(T_matmul_NT[v0, v1])
+ T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
+
+
# fmt: on
# pylint:
enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable
@@ -78,10 +156,17 @@ def
test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize():
postproc = RewriteParallelVectorizeUnroll()
sch = Schedule(Move_PUV)
assert postproc.apply(sch)
- print(sch.mod["main"].script())
mod = tvm.tir.transform.Simplify()(sch.mod)
tvm.ir.assert_structural_equal(mod["main"], Move_PUV0)
+def test_vectorize_inner_loop():
+ sch = Schedule(before_matmul_vectorize)
+ rule = RewriteParallelVectorizeUnroll()
+ assert rule.apply(sch)
+ tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize)
+
+
if __name__ == "__main__":
test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize()
+ test_vectorize_inner_loop()