This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 42f9a766ff [TOPI] Add padding for dense/batch matmul for x86 vnni
(#13385)
42f9a766ff is described below
commit 42f9a766ffdce46220dc46d9126998909829dd66
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Nov 15 12:32:18 2022 -0800
[TOPI] Add padding for dense/batch matmul for x86 vnni (#13385)
This added padding to make the shape of dense/batch matmul compatible with
VNNI instructions.
---
python/tvm/topi/x86/dense_alter_op.py | 35 ++++++++++++++++++++++++++++++-----
tests/python/relay/test_op_level1.py | 7 ++++---
tests/python/relay/test_op_level10.py | 16 ++++++++++++----
3 files changed, 46 insertions(+), 12 deletions(-)
diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py
index 0b195f487b..fd2b184a87 100644
--- a/python/tvm/topi/x86/dense_alter_op.py
+++ b/python/tvm/topi/x86/dense_alter_op.py
@@ -28,14 +28,13 @@ from .utils import target_has_vnni
from .. import nn
-def check_vnni_applicable(x, y):
+def check_vnni_applicable(x, y, allow_padding=False):
mcpu = tvm.target.Target.current().mcpu
return (
target_has_vnni(mcpu)
and "int8" in x.dtype
and "int8" in y.dtype
- and y.shape[-2] % 16 == 0
- and y.shape[-1] % 4 == 0
+ and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0))
)
@@ -87,7 +86,10 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
"""Legalizes s8, s8 -> s32 GEMM op for VNNI."""
-    if check_vnni_applicable(arg_types[0], arg_types[1]) and arg_types[0].dtype == "int8":
+ if (
+ check_vnni_applicable(arg_types[0], arg_types[1], allow_padding=True)
+ and arg_types[0].dtype == "int8"
+ ):
x, y = inputs
x = relay.cast(x, "int32")
x = relay.add(x, relay.const(128, "int32"))
@@ -98,7 +100,30 @@ def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
if need_expand:
adjust_shift = relay.expand_dims(adjust_shift, axis=1)
- out = op(x, y, **attrs)
+ analyzer = tvm.arith.Analyzer()
+ x_shape = arg_types[0].shape
+ y_shape = arg_types[1].shape
+ inst_n = 16
+ inst_k = 4
+ pad_n = analyzer.simplify((inst_n - y_shape[-2] % inst_n) % inst_n)
+ pad_k = analyzer.simplify((inst_k - y_shape[-1] % inst_k) % inst_k)
+ if pad_k != 0 or pad_n != 0:
+ ndim = len(x_shape)
+ unpadded_dims = [(0, 0)] * (ndim - 2)
+            padding_y = [(0, 0)] * (len(y_shape) - 2) + [(0, pad_n), (0, pad_k)]
+ padded_y = relay.nn.pad(y, pad_width=padding_y, pad_value=0)
+ if pad_k != 0:
+ padding_x = [(0, 0)] * (len(x_shape) - 1) + [(0, pad_k)]
+ padded_x = relay.nn.pad(x, pad_width=padding_x, pad_value=0)
+ else:
+ padded_x = x
+ out = op(padded_x, padded_y, **attrs)
+ if pad_n != 0:
+ begin = [0] * len(x_shape)
+ end = x_shape[:-2] + [x_shape[-2], y_shape[-2]]
+ out = relay.strided_slice(out, begin, end, slice_mode="size")
+ else:
+ out = op(x, y, **attrs)
return relay.subtract(out, adjust_shift)
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 7884fa35a4..1c93ee766a 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -753,9 +753,10 @@ def test_bitserial_dense():
@tvm.testing.requires_cascadelake
-def test_dense_vnni():
- data_shape = (32, 96)
- weight_shape = (128, 96)
[email protected]("m,n,k", [(32, 128, 96), (32, 128, 97)])
+def test_dense_vnni(m, n, k):
+ data_shape = (m, k)
+ weight_shape = (n, k)
for data_dtype in ["uint8", "int8"]:
data = relay.var("data", shape=data_shape, dtype=data_dtype)
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 5134ab156b..619a0b5a93 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -474,10 +474,18 @@ def test_batch_matmul(executor_kind):
@tvm.testing.requires_cascadelake
-def test_batch_matmul_vnni():
- x_shape = (16, 32, 96)
- y_shape = (16, 128, 96)
- z_shape = (16, 32, 128)
[email protected](
+ "b,m,n,k",
+ [
+ (16, 32, 128, 96),
+ (16, 32, 128, 97),
+ (16, 32, 129, 96),
+ ],
+)
+def test_batch_matmul_vnni(b, m, n, k):
+ x_shape = (b, m, k)
+ y_shape = (b, n, k)
+ z_shape = (b, m, n)
for lhs_dtype in ["uint8", "int8"]:
x = relay.var("x", shape=x_shape, dtype=lhs_dtype)