This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 328122675d [TIR][TOPI][x86][CI] Support skylake avx512 (#13621)
328122675d is described below
commit 328122675da7800944211e7ac0b21b3ed9398060
Author: Valery Chernov <[email protected]>
AuthorDate: Wed Jan 18 01:26:28 2023 +0400
[TIR][TOPI][x86][CI] Support skylake avx512 (#13621)
* add skylake-avx512 tests
* extend tests by skylake-avx512
* lint fixes
* fix typos
* typo fix
* TODOs for further development
* add temporarily commented-out tests for skylake-avx512 because schedules
and postprocs for it are not implemented yet; add TODOs for further checks and development
* update int8-acc32 test for vnni and avx512 w/o it
* pylint fix
* once more pylint fix
* fix Feature init for skylake
* fix test
* fix intrin names for assert for skylake
* small fix
* return back fast int8 intrinsic tests
* test connection of dense and batch_matmul to avx512 tensorization
* extend dense_alter_layout to avx512 (currently) instead of VNNI only; rename
some vnni identifiers to int8 for the sake of clarity
* more renaming of vnni to int8 in the dense schedule, compute, and strategy for the
sake of clarity
* update for batch_matmul with avx512
* extend space generator init for avx512. Add Default AVX512 schedule rules
* implement avx512 dot 16x4 intrin for the MetaSchedule (MS) default schedule rule
* small fix
* update
* pylint fixes
* test workaround for const alloc in tir
* test fix (broadcasting)
* remove excess instructions from dot_product_16x4_u8i8i32_avx512
* pylint fix
* skip asm check for askew weight shapes
* fix pylint
* revert test fix
* set number of args
* test fix
* fix const allocation in tir for avx512 dot 16x4
* fix signature of dot_product_16x4_u8i8i32_avx512
* use script instead of tvm.tir for const allocation
* extend auto tensorize test by skylake-avx512 target
* clean code
* update test_op_level1, resolve TODO
* small update test_op_level2
* update test_op_level10, resolve TODO
* update qnn legalize pass test, resolve TODOs
* pylint fixes
* update ms test for avx512
* update more ms test for avx512
* try to fix i386 CI tests
* fix intrin name for check
* skip test due to model downloading issue
* fix test failure
* use ORT for conv2d check
* lint fix after rebasing
* comment ORT part of test
* extend TIR schedule analysis and transform tests for avx512; unify test
classes
* extend test tir schedule tensorize for avx512
* extend test meta schedule vnni integration for avx512
* rename test file
* pylint fix
* tag fix
* update test meta schedule trace apply with avx512
* rollback test class unifying in utils
* pylint fixes
* separate TIRs for scheduled conv2d for vnni and avx512
* fix registering issue in test
* update conv+bias onnx model for intermediate test
* fix int16 overflow
* fix int16 overflow for dense test
* update input data for test of dense
* small rollback
* fix typos
* fix
* restart CI
* DefaultVNNI was renamed to DefaultLLVM for mutator
* rename test file for the sake of clarity
* DefaultVNNI was renamed to DefaultCPUTensorization for postproc
* remove resolved TODO
* DefaultVNNI and AVX512 for ScheduleRule were unified
* replace code to upstream with initial version
* fix arg type
* lint fix
* small fix
* lint fix
* fix typos
* rollback trace apply test for avx512 (reviewer remark)
* fix pylint
Co-authored-by: Valery Chernov <[email protected]>
---
include/tvm/meta_schedule/mutator.h | 2 -
include/tvm/meta_schedule/postproc.h | 4 +-
include/tvm/meta_schedule/schedule_rule.h | 4 +-
python/tvm/relay/qnn/op/legalizations.py | 4 +-
python/tvm/testing/utils.py | 29 ++++
python/tvm/tir/tensor_intrin/x86.py | 40 ++++++
python/tvm/topi/x86/batch_matmul.py | 25 ++--
python/tvm/topi/x86/dense.py | 19 +--
python/tvm/topi/x86/dense_alter_op.py | 18 +--
src/meta_schedule/mutator/mutator.cc | 2 -
src/meta_schedule/postproc/postproc.cc | 2 +-
src/meta_schedule/schedule_rule/schedule_rule.cc | 6 +-
.../space_generator/space_generator.cc | 19 ++-
tests/python/contrib/test_gemm_acc32_vnni.py | 160 ++++++++++-----------
tests/python/integration/test_auto_tensorize.py | 136 +++++++++++-------
tests/python/relay/test_op_level1.py | 24 +++-
tests/python/relay/test_op_level10.py | 45 ++++--
tests/python/relay/test_op_level2.py | 24 ++--
tests/python/relay/test_pass_qnn_legalize.py | 26 ++--
...on.py => test_meta_schedule_cpu_dot_product.py} | 62 +++++---
.../test_meta_schedule_relay_integration.py | 19 ++-
.../test_meta_schedule_schedule_rule_mlt_intrin.py | 23 +--
.../unittest/test_meta_schedule_trace_apply.py | 8 +-
.../python/unittest/test_tir_schedule_analysis.py | 15 +-
.../python/unittest/test_tir_schedule_tensorize.py | 14 +-
.../python/unittest/test_tir_schedule_transform.py | 38 ++---
26 files changed, 485 insertions(+), 283 deletions(-)
diff --git a/include/tvm/meta_schedule/mutator.h
b/include/tvm/meta_schedule/mutator.h
index 498b2797ad..1560c00f39 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -131,8 +131,6 @@ class Mutator : public runtime::ObjectRef {
FApply f_apply, FClone f_clone, FAsString
f_as_string);
/*! \brief Create default mutators for LLVM */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultLLVM();
- /*! \brief Create default mutators for x86 VNNI */
- TVM_DLL static Map<Mutator, FloatImm, void> DefaultVNNI();
/*! \brief Create default mutators for CUDA */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDA();
/*! \brief Create default mutators for CUDA with TensorCore */
diff --git a/include/tvm/meta_schedule/postproc.h
b/include/tvm/meta_schedule/postproc.h
index 06fa086c4b..85fb9003e8 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -163,8 +163,8 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc RewriteLayout();
/*! \brief Create default postprocessors for LLVM */
TVM_DLL static Array<Postproc, void> DefaultLLVM();
- /*! \brief Create default postprocessors for x86 VNNI */
- TVM_DLL static Array<Postproc, void> DefaultVNNI();
+ /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
+ TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
/*! \brief Create default postprocessors for CUDA */
TVM_DLL static Array<Postproc, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
diff --git a/include/tvm/meta_schedule/schedule_rule.h
b/include/tvm/meta_schedule/schedule_rule.h
index 16202e18bf..7995d1fcee 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -290,8 +290,8 @@ class ScheduleRule : public runtime::ObjectRef {
/*! \brief Create default schedule rules for LLVM */
TVM_DLL static Array<ScheduleRule, void> DefaultLLVM();
- /*! \brief Create default schedule rules for x86 VNNI */
- TVM_DLL static Array<ScheduleRule, void> DefaultVNNI();
+ /*! \brief Create default schedule rules for x86 (AVX512 and VNNI) */
+ TVM_DLL static Array<ScheduleRule, void> DefaultX86(const String& type);
/*! \brief Create default schedule rules for CUDA */
TVM_DLL static Array<ScheduleRule, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
diff --git a/python/tvm/relay/qnn/op/legalizations.py
b/python/tvm/relay/qnn/op/legalizations.py
index 9baabf36a9..ef368a016e 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -248,7 +248,7 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs,
types, relay_op):
Replacing QA + 128 with QA' and (zp_a + 128) with zp_a'
We get our new quantized uint8 tensor - scale * (QA' - zp_a')
- Similarly we can convert from int8 to uint8.
+ Similarly we can convert from uint8 to int8.
Parameters
----------
@@ -449,6 +449,7 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
@qnn_conv2d_legalize.register("cpu")
def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
+ # TODO(vvchernov): not only VNNI
# The VNNI transformations prefer uint8 x int8 datatypes.
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types,
relay.qnn.op.conv2d)
@@ -457,6 +458,7 @@ def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
@qnn_dense_legalize.register("cpu")
def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
+ # TODO(vvchernov): not only VNNI
# The VNNI transformations prefer uint8 x int8 datatypes.
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types,
relay.qnn.op.dense)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 899b054403..19669cd60c 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1027,6 +1027,28 @@ def _has_vnni():
return False
+# check avx512 intrinsic groups for SkyLake X
+def _has_slavx512():
+ # Check LLVM support
+ llvm_version = tvm.target.codegen.llvm_version_major()
+ is_llvm_support = llvm_version >= 8
+ arch = platform.machine()
+ # Only linux is supported for now.
+ if arch == "x86_64" and sys.platform.startswith("linux"):
+ with open("/proc/cpuinfo", "r") as content:
+ ctx = content.read()
+ check = (
+ "avx512f" in ctx
+ and "avx512cd" in ctx
+ and "avx512bw" in ctx
+ and "avx512dq" in ctx
+ and "avx512vl" in ctx
+ )
+ return check and is_llvm_support
+
+ return False
+
+
requires_arm_dot = Feature("arm_dot", "ARM dot product",
run_time_check=_arm_dot_supported)
@@ -1035,6 +1057,13 @@ requires_cascadelake = Feature(
)
+requires_skylake_avx512 = Feature(
+ "skylake_avx512",
+ "x86 SkyLake AVX512",
+ run_time_check=lambda: _has_slavx512() and _is_intel(),
+)
+
+
def _cmake_flag_enabled(flag):
flag = tvm.support.libinfo()[flag]
diff --git a/python/tvm/tir/tensor_intrin/x86.py
b/python/tvm/tir/tensor_intrin/x86.py
index d93167f9e6..c527d0d210 100644
--- a/python/tvm/tir/tensor_intrin/x86.py
+++ b/python/tvm/tir/tensor_intrin/x86.py
@@ -67,8 +67,48 @@ def dot_product_16x4_u8i8i32_vnni(
)
[email protected]_func
+def dot_product_16x4_u8i8i32_avx512(
+ A: T.Buffer((4,), "uint8", offset_factor=1),
+ B: T.Buffer((16, 4), "int8", offset_factor=1),
+ C: T.Buffer((16,), "int32", offset_factor=1),
+) -> None:
+ with T.block("root"):
+ T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+ T.writes(C[0:16])
+
+ A_u8x4 = A.vload([0], "uint8x4")
+ A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+ A_brdcst = T.broadcast(A_i32, 16)
+ A_u8x64 = T.reinterpret(A_brdcst, dtype="uint8x64")
+
+ B_i8x64 = B.vload([0, 0], dtype="int8x64")
+
+ Red = T.call_llvm_pure_intrin(
+ T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"),
+ T.uint32(2),
+ A_u8x64,
+ B_i8x64,
+ dtype="int16x32",
+ )
+
+ C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(
+ T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"),
+ T.uint32(2),
+ Red,
+ T.int16x32(1),
+ dtype="int32x16",
+ )
+
+
VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"
TensorIntrin.register(
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc,
dot_product_16x4_u8i8i32_vnni
)
+
+AVX512_DOT_16x4_INTRIN = "dot_16x4_avx512"
+
+TensorIntrin.register(
+ AVX512_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc,
dot_product_16x4_u8i8i32_avx512
+)
diff --git a/python/tvm/topi/x86/batch_matmul.py
b/python/tvm/topi/x86/batch_matmul.py
index 9f3bc29515..95408a924f 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -25,12 +25,12 @@ from tvm.contrib import cblas, mkl
from .. import generic, nn
from ..transform import layout_transform
from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline
-from .dense import dense_vnni_schedule, dense_amx_int8_schedule
+from .dense import dense_int8_schedule, dense_amx_int8_schedule
from .injective import schedule_injective_from_existing
-from .utils import target_has_vnni, target_has_amx
+from .utils import target_has_avx512, target_has_amx
[email protected]_topi_compute("batch_matmul_vnni.x86")
[email protected]_topi_compute("batch_matmul_int8.x86")
def batch_matmul_int8_compute(cfg, x, y, *_):
"""Compute for uint8 x int8 -> int32 batch_matmul"""
batch, m, k = x.shape
@@ -39,8 +39,8 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
_, n_o, _, n_i, _ = packed_y.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
- if target_has_vnni(mcpu):
- attrs_info = {"schedule_rule": "batch_matmul_vnni"}
+ if target_has_avx512(mcpu):
+ attrs_info = {"schedule_rule": "batch_matmul_int8"}
else:
attrs_info = None
@@ -60,13 +60,14 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
return z
-def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
- """Schedule batch_matmul compute using VNNI vpdpbusd instruction"""
+def batch_matmul_int8_schedule(cfg, s, C, O, layout_trans):
+ """Schedule batch_matmul compute using avx512 or lower instructions
+ including VNNI vpdpbusd instruction if possible"""
# C: The output of batched GEMM
# O: The output of the fused op
# Schedule the GEMM part
- s, fused_inner = dense_vnni_schedule(cfg, s, C, O, do_parallel=False)
+ s, fused_inner = dense_int8_schedule(cfg, s, C, O, do_parallel=False)
# Parallelize over batch
fused = s[O].fuse(O.op.axis[0], fused_inner)
s[O].parallel(fused)
@@ -228,9 +229,9 @@ def schedule_batch_matmul(cfg, outs):
return s
[email protected]_topi_schedule("batch_matmul_vnni.x86")
[email protected]_topi_schedule("batch_matmul_int8.x86")
def schedule_batch_matmul_int8(cfg, outs):
- """Schedule for batch_matmul_vnni"""
+ """Schedule for batch_matmul_int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu
@@ -239,8 +240,8 @@ def schedule_batch_matmul_int8(cfg, outs):
layout_trans = op.input_tensors[1]
if target_has_amx(mcpu):
batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0],
layout_trans)
- elif target_has_vnni(mcpu):
- batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0],
layout_trans)
+ elif target_has_avx512(mcpu):
+ batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0],
layout_trans)
traverse_inline(s, outs[0].op, _callback)
return s
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index bb99a63281..b697cf98a6 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -26,10 +26,10 @@ from tvm.contrib import cblas, dnnl, mkl
from .. import generic, tag
from ..utils import get_const_tuple, traverse_inline
-from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
+from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids
from .tensor_intrin import acc_32x32_int32_sapphirerapids
-from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx
+from .utils import get_simd_32bit_lanes, target_has_avx512, target_has_amx
def _schedule_dense_pack_template(cfg, s, C, O):
@@ -302,8 +302,8 @@ def schedule_dense_int8(cfg, outs):
if "dense_int8" in op.tag:
if target_has_amx(mcpu):
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
- elif target_has_vnni(mcpu):
- dense_vnni_schedule(cfg, s, op.output(0), outs[0])
+ elif target_has_avx512(mcpu):
+ dense_int8_schedule(cfg, s, op.output(0), outs[0])
traverse_inline(s, outs[0].op, _callback)
return s
@@ -315,8 +315,8 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
n_o, _, n_i, _ = packed_w.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
- if target_has_vnni(mcpu):
- target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"}
+ if target_has_avx512(mcpu):
+ target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"}
else:
target_attr = None
@@ -339,8 +339,9 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
return C
-def dense_vnni_schedule(cfg, s, C, O, do_parallel=True):
- """Schedule dense compute using VNNI vpdpbusd instruction"""
+def dense_int8_schedule(cfg, s, C, O, do_parallel=True):
+ """Schedule dense compute using avx512 or lower instructions
+ including VNNI vpdpbusd instruction if possible"""
# C: The output of GEMM
# O: The output of the fused op
def split_y(out):
@@ -361,7 +362,7 @@ def dense_vnni_schedule(cfg, s, C, O, do_parallel=True):
s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)
- pc = dot_16x1x16_uint8_int8_int32_cascadelake()
+ pc = dot_16x1x16_uint8_int8_int32()
s[C].tensorize(a_xi, pc)
if C == O:
diff --git a/python/tvm/topi/x86/dense_alter_op.py
b/python/tvm/topi/x86/dense_alter_op.py
index 2cb46b8291..a380b7fc9f 100644
--- a/python/tvm/topi/x86/dense_alter_op.py
+++ b/python/tvm/topi/x86/dense_alter_op.py
@@ -24,14 +24,14 @@ from tvm import autotvm
from .dense import _default_dense_pack_config
from ..utils import get_const_tuple
from ..nn import dense_alter_layout
-from .utils import target_has_vnni
-from .utils import target_has_amx
+from .utils import target_has_avx512, target_has_amx
from .. import nn
-def check_inst_applicable(x, y, allow_padding=False):
+def check_int8_applicable(x, y, allow_padding=False):
mcpu = tvm.target.Target.current().mcpu
- simd_avai = target_has_vnni(mcpu) or target_has_amx(mcpu)
+ # TODO(vvchernov): may be also target_has_avx2 or lower?
+ simd_avai = target_has_avx512(mcpu) or target_has_amx(mcpu)
return (
simd_avai
and "int8" in x.dtype
@@ -49,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
M, K = get_const_tuple(data_tensor.shape)
N, _ = get_const_tuple(weight_tensor.shape)
- if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype
== "uint8":
+ if check_int8_applicable(data_tensor, weight_tensor) and data_tensor.dtype
== "uint8":
weight_layout = "NC16n4c"
return relay.nn.contrib_dense_pack(inputs[0], inputs[1],
weight_layout, None, out_dtype)
@@ -86,10 +86,10 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
return None
-def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
+def int8_int8_legalize(inputs, arg_types, op, attrs, need_expand=False):
"""Legalizes s8, s8 -> s32 GEMM op for VNNI."""
if (
- check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True)
+ check_int8_applicable(arg_types[0], arg_types[1], allow_padding=True)
and arg_types[0].dtype == "int8"
):
x, y = inputs
@@ -135,7 +135,7 @@ def vnni_legalize(inputs, arg_types, op, attrs,
need_expand=False):
@nn.dense_legalize.register("cpu")
def _dense_legalize(attrs, inputs, arg_types):
"""Legalizes s8, s8 -> s32 dense for VNNI."""
- return vnni_legalize(inputs, arg_types, relay.nn.dense, attrs)
+ return int8_int8_legalize(inputs, arg_types, relay.nn.dense, attrs)
@nn.batch_matmul_legalize.register("cpu")
@@ -143,4 +143,4 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
"""Legalizes s8, s8 -> s32 batch_matmul for VNNI."""
if attrs["transpose_a"] or not attrs["transpose_b"]:
return None
- return vnni_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs,
need_expand=True)
+ return int8_int8_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs,
need_expand=True)
diff --git a/src/meta_schedule/mutator/mutator.cc
b/src/meta_schedule/mutator/mutator.cc
index 3cf43e1126..ddc2d73590 100644
--- a/src/meta_schedule/mutator/mutator.cc
+++ b/src/meta_schedule/mutator/mutator.cc
@@ -59,8 +59,6 @@ Map<Mutator, FloatImm> Mutator::DefaultLLVM() {
{Mutator::MutateParallel(/*max_jobs_per_core=*/16),
FloatImm(DataType::Float(64), 0.02)}};
}
-Map<Mutator, FloatImm> Mutator::DefaultVNNI() { return Mutator::DefaultLLVM();
}
-
Map<Mutator, FloatImm> Mutator::DefaultCUDA() {
return Map<Mutator, FloatImm>{
{Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
diff --git a/src/meta_schedule/postproc/postproc.cc
b/src/meta_schedule/postproc/postproc.cc
index 7730e4372f..bcd0cef4dd 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -59,7 +59,7 @@ Array<Postproc> Postproc::DefaultLLVM() {
};
}
-Array<Postproc> Postproc::DefaultVNNI() {
+Array<Postproc> Postproc::DefaultCPUTensorization() {
return Array<Postproc>{
Postproc::DisallowDynamicLoop(),
Postproc::RewriteParallelVectorizeUnroll(),
Postproc::RewriteReductionBlock(),
Postproc::RewriteTensorize(/*vectorize_init_loop=*/true),
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc
b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 1137032720..e25f0b1221 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -85,7 +85,9 @@ Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
};
}
-Array<ScheduleRule> ScheduleRule::DefaultVNNI() {
+Array<ScheduleRule> ScheduleRule::DefaultX86(const String& type) {
+ static const Map<String, String> intrins = {{"vnni", "dot_16x4_vnni"},
+ {"avx512", "dot_16x4_avx512"}};
return {
ScheduleRule::ApplyCustomRule(),
ScheduleRule::InlineConstantScalars(),
@@ -101,7 +103,7 @@ Array<ScheduleRule> ScheduleRule::DefaultVNNI() {
/*max_jobs_per_core=*/16,
/*max_innermost_factor=*/Integer(64)),
ScheduleRule::MultiLevelTilingWithIntrin(
- /*intrin_name=*/"dot_16x4_vnni",
+ /*intrin_name=*/intrins[type],
/*structure=*/"SSRSRS",
/*tile_binds=*/NullOpt,
/*max_innermost_factor=*/Integer(64),
diff --git a/src/meta_schedule/space_generator/space_generator.cc
b/src/meta_schedule/space_generator/space_generator.cc
index 926f86cc4f..2ce8d8fa11 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -29,6 +29,14 @@ String GetRuleKindFromTarget(const Target& target) {
if (target->GetAttr<String>("mcpu") &&
(*f_check_vnni)(target->GetAttr<String>("mcpu").value())) {
return "vnni";
+ } else {
+ static const PackedFunc* f_check_avx512 =
+ runtime::Registry::Get("tvm.topi.x86.utils.target_has_avx512");
+ ICHECK(f_check_avx512 != nullptr) << "The `target_has_avx512` func is
not in tvm registry.";
+ if (target->GetAttr<String>("mcpu") &&
+ (*f_check_avx512)(target->GetAttr<String>("mcpu").value())) {
+ return "avx512";
+ }
}
return "llvm";
}
@@ -73,6 +81,7 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const
TuneContext& context) {
Array<ScheduleRule> default_sch_rules;
Array<Postproc> default_postprocs;
Map<Mutator, FloatImm> default_mutator_probs;
+ // for target with skylake-avx512
if (kind == "llvm") {
default_sch_rules = ScheduleRule::DefaultLLVM();
default_postprocs = Postproc::DefaultLLVM();
@@ -90,9 +99,13 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const
TuneContext& context) {
default_postprocs = Postproc::DefaultHexagon();
default_mutator_probs = Mutator::DefaultHexagon();
} else if (kind == "vnni") {
- default_sch_rules = ScheduleRule::DefaultVNNI();
- default_postprocs = Postproc::DefaultVNNI();
- default_mutator_probs = Mutator::DefaultVNNI();
+ default_sch_rules = ScheduleRule::DefaultX86("vnni");
+ default_postprocs = Postproc::DefaultCPUTensorization();
+ default_mutator_probs = Mutator::DefaultLLVM();
+ } else if (kind == "avx512") {
+ default_sch_rules = ScheduleRule::DefaultX86("avx512");
+ default_postprocs = Postproc::DefaultCPUTensorization();
+ default_mutator_probs = Mutator::DefaultLLVM();
} else if (kind == "c") {
default_sch_rules = ScheduleRule::DefaultMicro();
default_postprocs = Postproc::DefaultMicro();
diff --git a/tests/python/contrib/test_gemm_acc32_vnni.py
b/tests/python/contrib/test_gemm_acc32_vnni.py
index 9cec823cc5..c01f7758cb 100644
--- a/tests/python/contrib/test_gemm_acc32_vnni.py
+++ b/tests/python/contrib/test_gemm_acc32_vnni.py
@@ -14,106 +14,102 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines,
len-as-condition
import tvm
import tvm.testing
from tvm import te
import numpy as np
-from tvm.topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
from tvm.topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32
-import pytest
[email protected]_llvm
[email protected]("skip because feature not enabled")
-def test_fc_int8_acc32():
- m = 1024
- n = 1024
- k = 1024
-
+def verify_fc_int8_acc32(m=1024, n=1024, k=1024, target="llvm
-mcpu=cascadelake"):
X = te.placeholder((m, k), name="X", dtype="uint8")
- W = te.placeholder((n, k), name="W", dtype="int8")
+ # W = te.placeholder((n, k), name="W", dtype="int8")
+
+ if not tvm.testing.device_enabled(target):
+ print("skip because %s is not enabled..." % target)
+ return
+
+ dev = tvm.device(target, 0)
+ # workaround for Target.current()
+ with tvm.target.Target(target) as target:
+ pc = dot_16x1x16_uint8_int8_int32()
+
+ ak = te.reduce_axis((0, k), name="k")
+ packedW = te.placeholder((n // 16, 16 * (k // 4), 4), name="packedW",
dtype="int8")
+
+ t_fc = te.compute(
+ (m, n),
+ lambda i, j: te.sum(
+ X[i, ak].astype("int32")
+ * packedW[
+ tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4) * 16 + j %
16, ak % 4
+ ].astype("int32"),
+ axis=ak,
+ ),
+ name="F",
+ )
+ t_sch = te.create_schedule(t_fc.op)
+ a_x, a_y = t_fc.op.axis
+ (a_k,) = t_fc.op.reduce_axis
+
+ a_yo, a_yi = t_sch[t_fc].split(a_y, factor=16)
+ a_xo, a_xi = t_sch[t_fc].split(a_x, factor=32)
+ a_ko, a_ki = t_sch[t_fc].split(a_k, factor=4)
+ a_koo, a_koi = t_sch[t_fc].split(a_ko, factor=4)
+ t_sch[t_fc].reorder(a_yo, a_xo, a_xi, a_koo, a_koi, a_yi, a_ki)
+
+ t_sch[t_fc].unroll(a_koi)
+ t_sch[t_fc].tensorize(a_yi, pc)
+
+ t_func = tvm.build(t_sch, [X, packedW, t_fc], target, name="intrinsic")
+ t_evaluator = t_func.time_evaluator(t_func.entry_name, dev, number=10)
+
+ # generate the plain data
+ a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
+ b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8")
+
+ packW = np.random.uniform(1, 10, size=(n // 16, 16 * (k // 4),
4)).astype("int8")
+ # This occurs in pre_compute stage
+ for r_idx in range(n // 16):
+ for s_idx in range(16 * (k // 4)):
+ for t_idx in range(4):
+ packW[r_idx][s_idx][t_idx] = b_[r_idx * 16 + s_idx %
16][(s_idx // 16) * 4 + t_idx]
+
+ x = tvm.nd.array(a_, dev)
+ w = tvm.nd.array(packW, dev)
+ y = tvm.nd.array(np.zeros((m, n), dtype="int32"), dev)
+ result = t_evaluator(x, w, y)
peak = 280
print("Peak {} Gops/s".format(peak))
- memory_ops = m * k + n * k + 2 * m * n
+ # memory_ops = m * k + n * k + 2 * m * n
gops_per_mm = 2 * m * n * k
+ gops_per_sec = gops_per_mm / result.mean / 1e9
+ # verify the correctness
+ tvm.testing.assert_allclose(y.numpy(), np.dot(a_, b_.T), rtol=0)
+ print(
+ "Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency:
{:.2f}".format(
+ result.mean * 1000, gops_per_sec, gops_per_sec / peak
+ )
+ )
+ # t_func.export_library("tensorize_acc32.o")
+
+
[email protected]_cascadelake
+def test_fc_int8_acc32_vnni():
# For LLVM < 8.0, it shows "'cascadelake' is not a recognized processor
for this target
# (ignoring processor)" error with the following setting. After LLVM 8.0
is enabled in the
# test, we should use cascadelake setting.
- def verify(target="llvm -mcpu=cascadelake"):
- if not tvm.testing.device_enabled(target):
- print("skip because %s is not enabled..." % target)
- return
-
- dev = tvm.device(target, 0)
- pc = dot_16x1x16_uint8_int8_int32_cascadelake()
- ak = te.reduce_axis((0, k), name="k")
- packedW = te.placeholder((n // 16, 16 * (k // 4), 4), name="packedW",
dtype="int8")
-
- t_fc = te.compute(
- (m, n),
- lambda i, j: te.sum(
- X[i, ak].astype("int32")
- * packedW[
- tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4) * 16 + j
% 16, ak % 4
- ].astype("int32"),
- axis=ak,
- ),
- name="F",
- )
- t_sch = te.create_schedule(t_fc.op)
- a_x, a_y = t_fc.op.axis
- (a_k,) = t_fc.op.reduce_axis
-
- a_yo, a_yi = t_sch[t_fc].split(a_y, factor=16)
- a_xo, a_xi = t_sch[t_fc].split(a_x, factor=32)
- a_ko, a_ki = t_sch[t_fc].split(a_k, factor=4)
- a_koo, a_koi = t_sch[t_fc].split(a_ko, factor=4)
- t_sch[t_fc].reorder(a_yo, a_xo, a_xi, a_koo, a_koi, a_yi, a_ki)
-
- t_sch[t_fc].unroll(a_koi)
- t_sch[t_fc].tensorize(a_yi, pc)
-
- t_func = tvm.build(t_sch, [X, packedW, t_fc], target, name="intrinsic")
- t_evaluator = t_func.time_evaluator(t_func.entry_name, dev, number=10)
-
- # generate the plain data
- a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
- b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8")
-
- packW = np.random.uniform(1, 10, size=(n // 16, 16 * (k // 4),
4)).astype("int8")
- # This occurs in pre_compute stage
- for r_idx in range(n // 16):
- for s_idx in range(16 * (k // 4)):
- for t_idx in range(4):
- packW[r_idx][s_idx][t_idx] = b_[r_idx * 16 + s_idx % 16][
- (s_idx // 16) * 4 + t_idx
- ]
-
- x = tvm.nd.array(a_, dev)
- w = tvm.nd.array(packW, dev)
- y = tvm.nd.array(np.zeros((m, n), dtype="int32"), dev)
- result = t_evaluator(x, w, y)
-
- gops_per_sec = gops_per_mm / result.mean / 1e9
- # verify the correctness
- tvm.testing.assert_allclose(y.numpy(), np.dot(a_, b_.T), rtol=0)
- print(
- "Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency:
{:.2f}".format(
- result.mean * 1000, gops_per_sec, gops_per_sec / peak
- )
- )
- t_func.export_library("tensorize_acc32.o")
+ verify_fc_int8_acc32()
- verify()
[email protected]_skylake_avx512
+def test_fc_int8_acc32_avx512():
+ verify_fc_int8_acc32(target="llvm -mcpu=skylake-avx512")
-if __name__ == "__main__":
- # The test requires Cascade Lake and newer Intel machines to generate the
- # correct AVX512 VNNI instruction. So, disabling the test.
- # test_fc_int8_acc32()
- pass
+if __name__ == "__main__":
+ test_fc_int8_acc32_vnni()
+ test_fc_int8_acc32_avx512()
diff --git a/tests/python/integration/test_auto_tensorize.py
b/tests/python/integration/test_auto_tensorize.py
index 572da53b34..70b2b875c1 100644
--- a/tests/python/integration/test_auto_tensorize.py
+++ b/tests/python/integration/test_auto_tensorize.py
@@ -29,52 +29,63 @@ from tvm.meta_schedule.testing.tlcbench import
load_quantized_bert_base
from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN
from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN as AVX512_INTRIN
-SCH_RULES_FOR_VNNI = [
- ms.schedule_rule.ApplyCustomRule(),
- ms.schedule_rule.AutoInline(
- into_producer=False,
- into_consumer=True,
- inline_const_tensor=True,
- disallow_if_then_else=True,
- require_injective=True,
- require_ordered=True,
- disallow_op=["tir.exp"],
- ),
- ms.schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
- ms.schedule_rule.MultiLevelTilingWithIntrin(
- VNNI_INTRIN,
- structure="SSRSRS",
- tile_binds=None,
- max_innermost_factor=64,
- vector_load_lens=None,
- reuse_read=None,
- reuse_write=ms.schedule_rule.ReuseType(
- req="may",
- levels=[1, 2],
- scope="global",
+
+CASCADELAKE_VNNI_TARGET = "llvm -mcpu=cascadelake -num-cores 4"
+SKYLAKE_AVX512_TARGET = "llvm -mcpu=skylake-avx512 -num-cores 4"
+
+
+def _get_schedule_rules_for_x86(intrin):
+ return [
+ ms.schedule_rule.ApplyCustomRule(),
+ ms.schedule_rule.AutoInline(
+ into_producer=False,
+ into_consumer=True,
+ inline_const_tensor=True,
+ disallow_if_then_else=True,
+ require_injective=True,
+ require_ordered=True,
+ disallow_op=["tir.exp"],
+ ),
+ ms.schedule_rule.AddRFactor(max_jobs_per_core=16,
max_innermost_factor=64),
+ ms.schedule_rule.MultiLevelTilingWithIntrin(
+ intrin,
+ structure="SSRSRS",
+ tile_binds=None,
+ max_innermost_factor=64,
+ vector_load_lens=None,
+ reuse_read=None,
+ reuse_write=ms.schedule_rule.ReuseType(
+ req="may",
+ levels=[1, 2],
+ scope="global",
+ ),
+ ),
+ ms.schedule_rule.MultiLevelTiling(
+ structure="SSRSRS",
+ tile_binds=None,
+ max_innermost_factor=64,
+ vector_load_lens=None,
+ reuse_read=None,
+ reuse_write=ms.schedule_rule.ReuseType(
+ req="may",
+ levels=[1, 2],
+ scope="global",
+ ),
),
- ),
- ms.schedule_rule.MultiLevelTiling(
- structure="SSRSRS",
- tile_binds=None,
- max_innermost_factor=64,
- vector_load_lens=None,
- reuse_read=None,
- reuse_write=ms.schedule_rule.ReuseType(
- req="may",
- levels=[1, 2],
- scope="global",
+ ms.schedule_rule.ParallelizeVectorizeUnroll(
+ max_jobs_per_core=16,
+ max_vectorize_extent=64,
+ unroll_max_steps=[0, 16, 64, 512],
+ unroll_explicit=True,
),
- ),
- ms.schedule_rule.ParallelizeVectorizeUnroll(
- max_jobs_per_core=16,
- max_vectorize_extent=64,
- unroll_max_steps=[0, 16, 64, 512],
- unroll_explicit=True,
- ),
- ms.schedule_rule.RandomComputeLocation(),
-]
+ ms.schedule_rule.RandomComputeLocation(),
+ ]
+
+
+SCH_RULES_FOR_VNNI = _get_schedule_rules_for_x86(VNNI_INTRIN)
+SCH_RULES_FOR_AVX512 = _get_schedule_rules_for_x86(AVX512_INTRIN)
def _get_sch_rules_for_dp4a(intrin):
@@ -177,6 +188,11 @@ def tune_and_test(relay_mod, data_np, weight_np, op_name,
target, sch_rules, pos
asm = lib.lib.get_source("asm")
assert "vpdpbusd" in asm
+ if "skylake-avx512" in target:
+ asm = lib.lib.get_source("asm")
+ assert "pmaddubs" in asm
+ assert "pmaddw" in asm
+
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("data", data_np)
runtime.run()
@@ -273,9 +289,12 @@ def _test_bert_int8(relay_mod, params, input_info, target,
sch_rules, postprocs)
@tvm.testing.requires_cascadelake
def test_vnni_dense():
- _test_dense(
- "uint8", SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI, "llvm
-mcpu=cascadelake -num-cores 4"
- )
+ _test_dense("uint8", SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI,
CASCADELAKE_VNNI_TARGET)
+
+
[email protected]_skylake_avx512
+def test_avx512_dense():
+ _test_dense("uint8", SCH_RULES_FOR_AVX512, POSTPROCS_FOR_VNNI,
SKYLAKE_AVX512_TARGET)
@pytest.mark.skip("Only tested locally on sm_86 (for cuda) which is not
supported by CI")
@@ -293,9 +312,12 @@ def test_dp4a_dense():
@tvm.testing.requires_cascadelake
def test_vnni_conv2d():
- _test_conv2d(
- "uint8", SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI, "llvm
-mcpu=cascadelake -num-cores 4"
- )
+ _test_conv2d("uint8", SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI,
CASCADELAKE_VNNI_TARGET)
+
+
[email protected]_skylake_avx512
+def test_avx512_conv2d():
+ _test_conv2d("uint8", SCH_RULES_FOR_AVX512, POSTPROCS_FOR_VNNI,
SKYLAKE_AVX512_TARGET)
@pytest.mark.skip("Only tested locally on sm_86 (for cuda) which is not
supported by CI")
@@ -319,12 +341,26 @@ def test_vnni_bert_int8():
relay_mod,
params,
input_info,
- "llvm -mcpu=cascadelake -num-cores 4",
+ CASCADELAKE_VNNI_TARGET,
SCH_RULES_FOR_VNNI,
POSTPROCS_FOR_VNNI,
)
+@tvm.testing.requires_skylake_avx512
+@pytest.mark.skip("Due to quantized BERT download issue")
+def test_avx512_bert_int8():
+    relay_mod, params, input_info = load_quantized_bert_base()
+    _test_bert_int8(
+        relay_mod,
+        params,
+        input_info,
+        SKYLAKE_AVX512_TARGET,
+        SCH_RULES_FOR_AVX512,
+        POSTPROCS_FOR_VNNI,
+    )
+
+
@tvm.testing.requires_gpu
@pytest.mark.skip("Slow on CI")
def test_dp4a_bert_int8():
diff --git a/tests/python/relay/test_op_level1.py
b/tests/python/relay/test_op_level1.py
index 3bb9918c7c..0549f4f2fb 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -760,9 +760,7 @@ def test_bitserial_dense():
assert yy.checked_type == relay.TensorType((m, 32), "int16")
-@tvm.testing.requires_cascadelake
-@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)])
-def test_dense_vnni(m, n, k):
+def dense_x86_test(m, n, k, target="llvm -mcpu=cascadelake", intrins=["vpdpbusd"]):
data_shape = (m, k)
weight_shape = (n, k)
@@ -774,12 +772,14 @@ def test_dense_vnni(m, n, k):
out = relay.nn.bias_add(dense, bias)
mod = tvm.IRModule.from_expr(out)
- target = "llvm -mcpu=cascadelake"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target)
- asm = lib.lib.get_source("asm")
- assert "vpdpbusd" in asm
+ # TODO(vvchernov): needs for avx512 arch, can be extended
+ if n % 16 == 0 and k % 4 == 0:
+ asm = lib.lib.get_source("asm")
+ for intrin in intrins:
+ assert intrin in asm
dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
@@ -846,6 +846,18 @@ def test_dense_amx_int8():
np.testing.assert_equal(out, ref)
+@tvm.testing.requires_cascadelake
+@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)])
+def test_dense_vnni(m, n, k):
+    dense_x86_test(m, n, k)
+
+
+@tvm.testing.requires_skylake_avx512
+@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)])
+def test_dense_skylake_avx512(m, n, k):
+    dense_x86_test(m, n, k, "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"])
+
+
@pytest.mark.skip("Requires GFX10 AMDGPU")
def test_dense_rocm_sdot4():
data_shape = (32, 96)
diff --git a/tests/python/relay/test_op_level10.py
b/tests/python/relay/test_op_level10.py
index cdf4e73484..ed044989ac 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -473,16 +473,7 @@ def test_batch_matmul(executor_kind):
verify_batch_matmul_with_inputs(executor_kind, x, x, x_np, x_np, (10, 27,
27))
-@tvm.testing.requires_cascadelake
-@pytest.mark.parametrize(
-    "b,m,n,k",
-    [
-        (16, 32, 128, 96),
-        (16, 32, 128, 97),
-        (16, 32, 129, 96),
-    ],
-)
-def test_batch_matmul_vnni(b, m, n, k):
+def batch_matmul_x86_test(b, m, n, k, target="llvm -mcpu=cascadelake", intrins=["vpdpbusd"]):
x_shape = (b, m, k)
y_shape = (b, n, k)
z_shape = (b, m, n)
@@ -495,12 +486,14 @@ def test_batch_matmul_vnni(b, m, n, k):
out = bmm + z
mod = tvm.IRModule.from_expr(out)
- target = "llvm -mcpu=cascadelake"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target)
- asm = lib.lib.get_source("asm")
- assert "vpdpbusd" in asm
+ # TODO(vvchernov): needs for avx512 arch, can be extended
+ if n % 16 == 0 and k % 4 == 0:
+ asm = lib.lib.get_source("asm")
+ for intrin in intrins:
+ assert intrin in asm
dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
@@ -575,6 +568,32 @@ def test_batch_matmul_amx(b, m, n, k):
np.testing.assert_equal(out, ref)
+@tvm.testing.requires_cascadelake
+@pytest.mark.parametrize(
+    "b,m,n,k",
+    [
+        (16, 32, 128, 96),
+        (16, 32, 128, 97),
+        (16, 32, 129, 96),
+    ],
+)
+def test_batch_matmul_vnni(b, m, n, k):
+    batch_matmul_x86_test(b, m, n, k)
+
+
+@tvm.testing.requires_skylake_avx512
+@pytest.mark.parametrize(
+    "b,m,n,k",
+    [
+        (16, 32, 128, 96),
+        (16, 32, 128, 97),
+        (16, 32, 129, 96),
+    ],
+)
+def test_batch_matmul_skylake_avx512(b, m, n, k):
+    batch_matmul_x86_test(b, m, n, k, "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"])
+
+
@pytest.mark.skip("Requires GFX10 AMDGPU")
def test_batch_matmul_rocm_sdot4():
x_shape = (16, 32, 96)
diff --git a/tests/python/relay/test_op_level2.py
b/tests/python/relay/test_op_level2.py
index ca1adf9400..f7cfc81fb2 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -1696,7 +1696,7 @@ class TestConv2DInt8Intrinsics:
elif "cascadelake" in target:
return "vpdpbusd"
else:
- assert False, "Target should be Skylake or Cascadelake"
+ assert False, "Target should be Nehalem or core-avx2 or Skylake or
Cascadelake"
@tvm.testing.fixture
def assembly(
@@ -2137,7 +2137,7 @@ def test_conv2d_nhwc_dnnl():
np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
-def _test_conv2d_int8_alter_dtype(data_dtype, target, dot_product_instr):
+def _test_conv2d_int8_alter_dtype(data_dtype, target, dot_product_instrs):
def get_conv2d_nchw(
d_shape,
w_shape,
@@ -2168,16 +2168,16 @@ def _test_conv2d_int8_alter_dtype(data_dtype, target,
dot_product_instr):
bias = relay.var("bias", shape=bias_shape, dtype="int32")
bias_np = np.random.randint(low=-127, high=128,
size=bias_shape).astype("int32")
- weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8")
+ weight_np = np.random.uniform(-32, 32, size=weight_shape).astype("int8")
conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype)
bias_add = relay.add(conv2d, bias)
mod = tvm.IRModule.from_expr(bias_add)
if data_dtype == "uint8":
- data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8")
+ data_np = np.random.uniform(0, 64, size=data_shape).astype("uint8")
else:
- data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8")
+ data_np = np.random.uniform(-32, 32, size=data_shape).astype("int8")
params = {"weight": weight_np, "bias": bias_np}
@@ -2194,7 +2194,8 @@ def _test_conv2d_int8_alter_dtype(data_dtype, target,
dot_product_instr):
):
lib = relay.build(mod, target=target, params=params)
- assert dot_product_instr in lib.lib.get_source("asm")
+ for dot_product_instr in dot_product_instrs:
+ assert dot_product_instr in lib.lib.get_source("asm")
rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
@@ -2210,13 +2211,20 @@ def _test_conv2d_int8_alter_dtype(data_dtype, target,
dot_product_instr):
@tvm.testing.requires_arm_dot
def test_conv2d_int8_alter_dtype_arm():
_test_conv2d_int8_alter_dtype(
- "uint8", "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
"sdot"
+ "uint8", "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
["sdot"]
)
@tvm.testing.requires_cascadelake
def test_conv2d_int8_alter_dtype_vnni():
- _test_conv2d_int8_alter_dtype("int8", "llvm -mcpu=cascadelake", "vpdpbusd")
+ _test_conv2d_int8_alter_dtype("int8", "llvm -mcpu=cascadelake",
["vpdpbusd"])
+
+
[email protected]_skylake_avx512
+def test_conv2d_int8_alter_dtype_avx512():
+ _test_conv2d_int8_alter_dtype(
+ "int8", "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"]
+ )
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_qnn_legalize.py
b/tests/python/relay/test_pass_qnn_legalize.py
index a30cd1e73e..c64b30a212 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -136,11 +136,12 @@ def test_qnn_legalize_qnn_conv2d():
#############################################################
# Check transformations for platforms with fast Int8 support.
#############################################################
- # Check that Intel VNNI gets picked up.
- with tvm.target.Target("llvm -mcpu=skylake-avx512"):
- mod = relay.transform.InferType()(mod)
- legalized_mod = relay.qnn.transform.Legalize()(mod)
- assert "cast" in legalized_mod.astext() and "qnn.conv2d" in
legalized_mod.astext()
+ # Check that Intel AVX512 (with or w/o VNNI) gets picked up.
+ for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]:
+ with tvm.target.Target(target):
+ mod = relay.transform.InferType()(mod)
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert "cast" in legalized_mod.astext() and "qnn.conv2d" in
legalized_mod.astext()
# Since same dtype, there should not be any transformation
with tvm.target.Target(
@@ -167,7 +168,7 @@ def test_qnn_legalize_qnn_conv2d():
#############################################################
# Check transformations for platforms with fast Int8 support.
#############################################################
- # Check no transformation for Intel VNNI.
+ # Check no transformation for Intel AVX512.
with tvm.target.Target("llvm -mcpu=skylake-avx512"):
mod = relay.transform.InferType()(mod)
legalized_mod = relay.qnn.transform.Legalize()(mod)
@@ -229,11 +230,12 @@ def test_qnn_legalize_qnn_dense():
#############################################################
# Check transformations for platforms with fast Int8 support.
#############################################################
- # Check that Intel VNNI gets picked up.
- with tvm.target.Target("llvm -mcpu=skylake-avx512"):
- mod = relay.transform.InferType()(mod)
- legalized_mod = relay.qnn.transform.Legalize()(mod)
- assert "cast" in legalized_mod.astext() and "qnn.dense" in
legalized_mod.astext()
+ # Check that Intel AVX512 (with or w/o VNNI) gets picked up.
+ for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]:
+ with tvm.target.Target(target):
+ mod = relay.transform.InferType()(mod)
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert "cast" in legalized_mod.astext() and "qnn.dense" in
legalized_mod.astext()
# Since same dtype, there should not be any transformation
with tvm.target.Target(
@@ -260,7 +262,7 @@ def test_qnn_legalize_qnn_dense():
#############################################################
# Check transformations for platforms with fast Int8 support.
#############################################################
- # Check no transformation for Intel VNNI.
+ # Check no transformation for Intel AVX512.
with tvm.target.Target("llvm -mcpu=skylake-avx512"):
mod = relay.transform.InferType()(mod)
legalized_mod = relay.qnn.transform.Legalize()(mod)
diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py
b/tests/python/unittest/test_meta_schedule_cpu_dot_product.py
similarity index 83%
rename from tests/python/unittest/test_meta_schedule_vnni_integration.py
rename to tests/python/unittest/test_meta_schedule_cpu_dot_product.py
index 3bbe916472..6dc72d6933 100644
--- a/tests/python/unittest/test_meta_schedule_vnni_integration.py
+++ b/tests/python/unittest/test_meta_schedule_cpu_dot_product.py
@@ -28,6 +28,7 @@ from tvm._ffi import register_func
from tvm.tir.schedule import BlockRV, Schedule
from tvm.tir.schedule.analysis import has_block
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN as AVX512_INTRIN
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s",
@@ -36,9 +37,9 @@ logging.basicConfig(
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
-def _schedule_dense(m: Optional[int], do_tune: bool):
+def _schedule_dense(m: Optional[int], do_tune: bool, intrin=VNNI_INTRIN):
"""Manually schedule a dense block, created from TE compute op via
CreatePrimFunc,
- using VNNI instruction.
+ using VNNI or AVX512 instructions.
"""
def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool:
@@ -47,7 +48,7 @@ def _schedule_dense(m: Optional[int], do_tune: bool):
if dense_block is None:
assert has_block(sch, "compute")
dense_block = sch.get_block("compute")
- assert "dense_vnni" in
sch.get(dense_block).annotations["schedule_rule"]
+ assert "dense_int8" in
sch.get(dense_block).annotations["schedule_rule"]
post_blocks = sch.get_consumers(dense_block)
if len(post_blocks) > 0:
@@ -90,7 +91,7 @@ def _schedule_dense(m: Optional[int], do_tune: bool):
dec = sch.decompose_reduction(dense_block, a_ko)
init_loop = sch.get_loops(dec)[-1]
sch.vectorize(init_loop)
- sch.tensorize(a_xi, VNNI_INTRIN)
+ sch.tensorize(a_xi, intrin)
return True
return schedule_fn
@@ -109,10 +110,10 @@ def _relay_dense(m, n, k):
out_dtype="int32",
)
relay_mod = tvm.IRModule.from_expr(out)
- data = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
+ data = np.random.randint(0, 5, size=(m, k), dtype="uint8")
params = {
- "weight": np.random.uniform(1, 10, size=(n, k)).astype("int8"),
- "bias": np.random.uniform(1, 10, size=(n,)).astype("int32"),
+ "weight": np.random.randint(0, 5, size=(n, k), dtype="int8"),
+ "bias": np.random.randint(0, 5, size=(n,), dtype="int32"),
}
def f_check(lib, dev):
@@ -135,10 +136,7 @@ def _relay_dense(m, n, k):
return relay_mod, params, f_check
-@tvm.testing.requires_cascadelake
-def test_vnni_schedule_fn_database():
-    m, n, k = 1024, 1024, 1024
-    target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4")
+def schedule_16x4_dense_fn_database(target, intrin, m=1024, n=1024, k=1024):
dev = tvm.cpu(0)
relay_mod, params, f_check = _relay_dense(m, n, k)
@@ -146,6 +144,7 @@ def test_vnni_schedule_fn_database():
_schedule_dense(
m=m,
do_tune=False,
+ intrin=intrin,
)
), tvm.transform.PassContext(
opt_level=3,
@@ -167,21 +166,32 @@ def test_vnni_schedule_fn_database():
@tvm.testing.requires_cascadelake
-def test_vnni_schedule_fn_tune():
+def test_vnni_schedule_fn_database():
+    target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake -num-cores=4")
+    schedule_16x4_dense_fn_database(target, VNNI_INTRIN)
+
+
+@tvm.testing.requires_skylake_avx512
+def test_avx512_schedule_fn_database():
+    target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=skylake-avx512 -num-cores=4")
+    schedule_16x4_dense_fn_database(target, AVX512_INTRIN, 16, 16, 16)
+
+
+def schedule_16x4_dense_fn_tune(target, intrin, m=1024, n=1024, k=1024):
# pylint: disable=W0105
"""
We can inject and apply a custom TIR scheduling to a TE compute of
interest, using
the "schedule_rule" annotation. For example, in topi/x86/dense.py we have
the following
- declaration for int8 dense targeting the VNNI instruction.
+ declaration for int8 dense targeting the VNNI or AVX512 instructions.
C = te.compute(
...
- attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"},
+ attrs={"schedule_rule": "meta_schedule.x86.dense_int8"},
)
When the MetaSchedule encounters a TensorIR block with the "schedule_rule"
annotation,
it looks up the packed func registry for a function that is associated
with the given schedule
- rule key ("meta_schedule.x86.dense_vnni" in this example). The signature
of such custom
+ rule key ("meta_schedule.x86.dense_int8" in this example). The signature
of such custom
schedule functions must be
(tir.schedule.Schedule, tir.schedule.BlockRV) ->
[tir.schedule.Schedule].
@@ -191,14 +201,12 @@ def test_vnni_schedule_fn_tune():
The relevant code is in
`src/meta_schedule/space_generator/apply_custom_rule.cc`.
"""
- def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV):
- _schedule_dense(m=None, do_tune=True)(sch, dense_block)
+ def schedule_rule_dense_16x4(sch: Schedule, dense_block: BlockRV):
+ _schedule_dense(m=None, do_tune=True, intrin=intrin)(sch, dense_block)
return [sch]
- register_func("meta_schedule.x86.dense_vnni", schedule_rule_dense_vnni)
+ register_func("meta_schedule.x86.dense_int8", schedule_rule_dense_16x4,
override=True)
- m, n, k = 1024, 1024, 1024
- target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake
-num-cores=4")
dev = tvm.cpu(0)
relay_mod, params, f_check = _relay_dense(m, n, k)
@@ -247,6 +255,20 @@ def test_vnni_schedule_fn_tune():
f_check(lib, dev)
+@tvm.testing.requires_cascadelake
+def test_vnni_schedule_fn_tune():
+    target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake -num-cores=4")
+    schedule_16x4_dense_fn_tune(target, VNNI_INTRIN)
+
+
+@tvm.testing.requires_skylake_avx512
+def test_avx512_schedule_fn_tune():
+    target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=skylake-avx512 -num-cores=4")
+    schedule_16x4_dense_fn_tune(target, AVX512_INTRIN, 16, 16, 16)
+
+
 if __name__ == """__main__""":
     test_vnni_schedule_fn_database()
+    test_avx512_schedule_fn_database()
     test_vnni_schedule_fn_tune()
+    test_avx512_schedule_fn_tune()
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py
b/tests/python/unittest/test_meta_schedule_relay_integration.py
index d3731cfa1b..795890de08 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -316,9 +316,8 @@ def
test_meta_schedule_integration_extract_from_resnet_with_filter_func():
assert t.task_name in expected_task_names, t.task_name
-@pytest.mark.skip("Too slow on CI")
-def extract_task_qbert():
-    def _test(mod, params, target):
+def extract_task_qbert(target, sch_rule_tag):
+    def _test(mod, params, target, sch_rule_tag):
extracted_tasks = ms.relay_integration.extract_tasks(mod, target,
params)
tune_tasks = list(
filter(
@@ -341,10 +340,20 @@ def extract_task_qbert():
annotations = sch.get(block).annotations
assert "schedule_rule" in annotations
- assert "vnni" in annotations["schedule_rule"]
+ assert sch_rule_tag in annotations["schedule_rule"]
mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128)
- _test(mod, params, target="llvm -mcpu=cascadelake")
+ _test(mod, params, target=target, sch_rule_tag=sch_rule_tag)
+
+
+@pytest.mark.skip("Too slow on CI")
+def extract_task_qbert_vnni():
+    extract_task_qbert("llvm -mcpu=cascadelake", "vnni")
+
+
+@pytest.mark.skip("Too slow on CI")
+def extract_task_qbert_avx512():
+    extract_task_qbert("llvm -mcpu=skylake-avx512", "avx512")
@tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image
is too old")
diff --git
a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
index 54f342c3a5..4667626f17 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
@@ -26,9 +26,10 @@ from tvm.script import tir as T
from tvm.target import Target
from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN as AVX512_INTRIN
-def test_vnni_conv2d_nchwc():
+def test_x86_conv2d_nchwc(intrin=VNNI_INTRIN, target="llvm -mcpu=cascadelake
-num-cores=4"):
@T.prim_func
def conv2d_nchwc(
placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
@@ -68,7 +69,7 @@ def test_vnni_conv2d_nchwc():
# fmt: off
@T.prim_func
- def vnni_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16),
"uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
+ def x86_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8:
T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16],
dtype="int32")
for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in
T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1):
@@ -86,7 +87,7 @@ def test_vnni_conv2d_nchwc():
ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0)
T.reads(placeholder[n, ic_outer, oh + kh, ow + kw,
ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw,
ic_f_inner, 0 : 16, 0 : 4])
T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 :
16])
-
T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
+ T.block_attr({"meta_schedule.auto_tensorize":intrin})
with T.init():
for i4_1 in T.serial(16):
with T.block("conv2d_NCHWc_int8_init"):
@@ -113,7 +114,7 @@ def test_vnni_conv2d_nchwc():
conv2d_NCHWc_int8[v0, v1, v2, v3, v4] =
conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4]
@T.prim_func
- def vnni_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16),
"uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
+ def x86_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8:
T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16],
dtype="int32")
for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1):
@@ -131,7 +132,7 @@ def test_vnni_conv2d_nchwc():
ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0)
T.reads(placeholder[n, ic_outer, oh + kh, ow + kw,
ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw,
ic_f_inner, 0 : 16, 0 : 4])
T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 :
16])
-
T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
+ T.block_attr({"meta_schedule.auto_tensorize":intrin})
with T.init():
for i4_1 in T.serial(16):
with T.block("conv2d_NCHWc_int8_init"):
@@ -158,7 +159,7 @@ def test_vnni_conv2d_nchwc():
conv2d_NCHWc_int8[v0, v1, v2, v3, v4] =
conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4]
@T.prim_func
- def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16),
"uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
+ def x86_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8:
T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1,
i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1,
i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1,
1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1):
with T.block("conv2d_NCHWc_int8_o"):
@@ -174,7 +175,7 @@ def test_vnni_conv2d_nchwc():
ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0)
T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner
* 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw,
ic_f_inner, 0 : 16, 0 : 4])
T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16])
- T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
+ T.block_attr({"meta_schedule.auto_tensorize":intrin})
with T.init():
for i4_1 in T.serial(16):
with T.block("conv2d_NCHWc_int8_init"):
@@ -228,7 +229,6 @@ def test_vnni_conv2d_nchwc():
]
mod = conv2d_nchwc
- target = Target("llvm -mcpu=cascadelake -num-cores=4")
actual = generate_design_space(
kind="llvm",
mod=mod,
@@ -236,7 +236,7 @@ def test_vnni_conv2d_nchwc():
types=None,
sch_rules=[
ms.schedule_rule.MultiLevelTilingWithIntrin(
- VNNI_INTRIN,
+ intrin,
structure="SSRSRS",
tile_binds=None,
max_innermost_factor=64,
@@ -249,7 +249,7 @@ def test_vnni_conv2d_nchwc():
check_sketches(
mod,
sketches=actual,
- expected_mods=[vnni_conv2d_nchwc_0, vnni_conv2d_nchwc_1,
vnni_conv2d_nchwc_2],
+ expected_mods=[x86_conv2d_nchwc_0, x86_conv2d_nchwc_1,
x86_conv2d_nchwc_2],
expected_decisions=[decision_0, decision_1, decision_2],
)
@@ -417,7 +417,8 @@ def test_dp4a_dense_no_tensorize_2():
if __name__ == "__main__":
- test_vnni_conv2d_nchwc()
+ test_x86_conv2d_nchwc()
+ test_x86_conv2d_nchwc(AVX512_INTRIN, "llvm -mcpu=skylake-avx512
-num-cores=4")
test_dp4a_dense()
test_dp4a_dense_no_tensorize_1()
test_dp4a_dense_no_tensorize_2()
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py
b/tests/python/unittest/test_meta_schedule_trace_apply.py
index 9a62207fa2..43b9eb8bbb 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -25,6 +25,8 @@ from tvm.tir.tensor_intrin.cuda import *
from tvm.target import Target
from tvm.target.codegen import llvm_lookup_intrinsic_id
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+
# fmt: off
@tvm.script.ir_module
@@ -2553,9 +2555,7 @@ def test_conv2d_int8_vnni():
l36, l37, l38, l39, l40, l41, l42, l43, l44, l45, l46, l47 =
sch.get_loops(block=b1)
sch.reorder(l42, l43, l44, l45, l46, l35, l33)
b48 = sch.blockize(loop=l35)
- sch.annotate(
- block_or_loop=b48, ann_key="meta_schedule.auto_tensorize",
ann_val="dot_16x4_vnni"
- )
+ sch.annotate(block_or_loop=b48,
ann_key="meta_schedule.auto_tensorize", ann_val=VNNI_INTRIN)
l49, l50, l51, l52, l53, l54, l55, l56, l57, l58 =
sch.get_loops(block=b48)
v59, v60, v61, v62 = sch.sample_perfect_tile(
loop=l49, n=4, max_innermost_factor=64, decision=[1, 1, 1, 1]
@@ -2729,7 +2729,7 @@ def test_conv2d_int8_vnni():
sch.vectorize(loop=l193)
b194 = sch.get_block(name="conv2d_NCHWc_int8_o_update",
func_name="main")
sch.unannotate(block_or_loop=b194,
ann_key="meta_schedule.auto_tensorize")
- sch.tensorize(block_or_loop=b194, tensor_intrin="dot_16x4_vnni")
+ sch.tensorize(block_or_loop=b194, tensor_intrin=VNNI_INTRIN)
vnni_id = llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512")
verify(
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py
b/tests/python/unittest/test_tir_schedule_analysis.py
index e0667da6fe..38bd4bba14 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -146,7 +146,7 @@ def test_suggest_index_map_winograd():
@tvm.script.ir_module
-class DenseVNNIModule:
+class DenseTIRModule:
@T.prim_func
def main(
placeholder: T.Buffer[(1024, 1024), "uint8"],
@@ -170,7 +170,7 @@ class DenseVNNIModule:
@tvm.script.ir_module
-class Conv2dNCHWcVNNIModule:
+class Conv2dNCHWcTIRModule:
@T.prim_func
def main(
placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
@@ -202,7 +202,8 @@ class Conv2dNCHWcVNNIModule:
conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] =
conv2d_NCHWc_int8[
n, oc_chunk, oh, ow, oc_block
] + T.cast(
- placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4
+ ic_s_inner], "int32"
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4
+ ic_s_inner],
+ "int32",
) * T.cast(
placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner],
"int32",
@@ -222,8 +223,8 @@ def collect_loops(prim_func):
return loops
-def test_get_tensorize_loop_mapping_dense_vnni():
- s = Schedule(DenseVNNIModule)
+def test_get_tensorize_loop_mapping_dense_16x4():
+ s = Schedule(DenseTIRModule)
block = s.get_block("compute")
info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
@@ -240,8 +241,8 @@ def test_get_tensorize_loop_mapping_dense_vnni():
assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k)
-def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
- s = Schedule(Conv2dNCHWcVNNIModule)
+def test_get_tensorize_loop_mapping_conv2d_nchwc_16x4():
+ s = Schedule(Conv2dNCHWcTIRModule)
block = s.get_block("conv2d_NCHWc_int8")
info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py
b/tests/python/unittest/test_tir_schedule_tensorize.py
index fc0bdc146c..4847f261a3 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -29,7 +29,7 @@ from tvm.tir.tensor_intrin.arm_cpu import (
ARM_DOT_4x4_i8_SDOT_INTRIN,
)
from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
-from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN,
AVX512_DOT_16x4_INTRIN
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN,
VDMPY_i16i16i32_INTRIN
# fmt: off
@@ -557,7 +557,7 @@ def get_matmul_packed(m, n, k, lhs_type, rhs_dtype="int8"):
return te.create_prim_func([X, W, matmul])
-def test_tensorize_vnni():
+def tensorize_16x4_test(intrin=VNNI_DOT_16x4_INTRIN):
m, n, k = 128, 128, 128
func = get_matmul_packed(m, n, k, "uint8")
@@ -572,11 +572,19 @@ def test_tensorize_vnni():
sch.reorder(ko, ji, ki)
sch.decompose_reduction(block, ko)
- sch.tensorize(ji, VNNI_DOT_16x4_INTRIN)
+ sch.tensorize(ji, intrin)
verify_trace_roundtrip(sch=sch, mod=func)
+def test_tensorize_vnni():
+ tensorize_16x4_test()
+
+
+def test_tensorize_avx512():
+ tensorize_16x4_test(AVX512_DOT_16x4_INTRIN)
+
+
def test_tensorize_arm_dot():
m, n, k = 128, 128, 128
diff --git a/tests/python/unittest/test_tir_schedule_transform.py
b/tests/python/unittest/test_tir_schedule_transform.py
index e812587e66..c068385f0a 100644
--- a/tests/python/unittest/test_tir_schedule_transform.py
+++ b/tests/python/unittest/test_tir_schedule_transform.py
@@ -18,11 +18,11 @@ import tvm
from tvm.script import tir as T
from tvm.tir import Schedule
from tvm.tir.schedule.transform import tile_with_tensor_intrin
-from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN,
AVX512_DOT_16x4_INTRIN
@tvm.script.ir_module
-class DenseVNNIModule:
+class DenseTIRModule:
@T.prim_func
def main(
placeholder: T.Buffer[(1024, 1024), "uint8"],
@@ -46,7 +46,7 @@ class DenseVNNIModule:
@tvm.script.ir_module
-class DenseVNNIModuleTiled:
+class DenseTIRModuleTiled:
@T.prim_func
def main(
placeholder: T.Buffer[(1024, 1024), "uint8"],
@@ -72,7 +72,7 @@ class DenseVNNIModuleTiled:
@tvm.script.ir_module
-class Conv2dNCHWcVNNIModule:
+class Conv2dNCHWcTIRModule:
@T.prim_func
def main(
placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
@@ -104,7 +104,8 @@ class Conv2dNCHWcVNNIModule:
conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] =
conv2d_NCHWc_int8[
n, oc_chunk, oh, ow, oc_block
] + T.cast(
- placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4
+ ic_s_inner], "int32"
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4
+ ic_s_inner],
+ "int32",
) * T.cast(
placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner],
"int32",
@@ -112,7 +113,7 @@ class Conv2dNCHWcVNNIModule:
@tvm.script.ir_module
-class Conv2dNCHWcVNNIModuleTiled:
+class Conv2dNCHWcTIRModuleTiled:
@T.prim_func
def main(
placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
@@ -141,35 +142,38 @@ class Conv2dNCHWcVNNIModuleTiled:
conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] =
conv2d_NCHWc_int8[
n, oc_chunk, oh, ow, oc_block
] + T.cast(
- placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4
+ ic_s_inner], "int32"
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4
+ ic_s_inner],
+ "int32",
) * T.cast(
placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner],
"int32",
)
-def test_tile_with_tensor_intrin_dense_vnni():
- s = Schedule(DenseVNNIModule)
+def test_tile_with_tensor_intrin_dense(intrin=VNNI_DOT_16x4_INTRIN):
+ s = Schedule(DenseTIRModule)
block = s.get_block("compute")
- tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)
+ tiled_loop = tile_with_tensor_intrin(s, block, intrin)
_, _, _, i1_1, _ = s.get_loops(block)
assert s.get(tiled_loop) == s.get(i1_1)
- tvm.ir.assert_structural_equal(s.mod, DenseVNNIModuleTiled)
+ tvm.ir.assert_structural_equal(s.mod, DenseTIRModuleTiled)
-def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
- s = Schedule(Conv2dNCHWcVNNIModule)
+def test_tile_with_tensor_intrin_conv2d_nchwc(intrin=VNNI_DOT_16x4_INTRIN):
+ s = Schedule(Conv2dNCHWcTIRModule)
block = s.get_block("conv2d_NCHWc_int8")
- tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)
+ tiled_loop = tile_with_tensor_intrin(s, block, intrin)
tiled_loops = s.get_loops(block)
assert len(tiled_loops) == 12
assert s.get(tiled_loop) == s.get(tiled_loops[-2])
- tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled)
+ tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcTIRModuleTiled)
if __name__ == "__main__":
- test_tile_with_tensor_intrin_dense_vnni()
- test_tile_with_tensor_intrin_conv2d_nchwc_vnni()
+ test_tile_with_tensor_intrin_dense()
+ test_tile_with_tensor_intrin_dense(AVX512_DOT_16x4_INTRIN)
+ test_tile_with_tensor_intrin_conv2d_nchwc()
+ test_tile_with_tensor_intrin_conv2d_nchwc(AVX512_DOT_16x4_INTRIN)