This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 42f9a766ff [TOPI] Add padding for dense/batch matmul for x86 vnni (#13385)
42f9a766ff is described below

commit 42f9a766ffdce46220dc46d9126998909829dd66
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Nov 15 12:32:18 2022 -0800

    [TOPI] Add padding for dense/batch matmul for x86 vnni (#13385)
    
    This added padding to make the shape of dense/batch matmul compatible with VNNI instructions.
---
 python/tvm/topi/x86/dense_alter_op.py | 35 ++++++++++++++++++++++++++++++-----
 tests/python/relay/test_op_level1.py  |  7 ++++---
 tests/python/relay/test_op_level10.py | 16 ++++++++++++----
 3 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py
index 0b195f487b..fd2b184a87 100644
--- a/python/tvm/topi/x86/dense_alter_op.py
+++ b/python/tvm/topi/x86/dense_alter_op.py
@@ -28,14 +28,13 @@ from .utils import target_has_vnni
 from .. import nn
 
 
-def check_vnni_applicable(x, y):
+def check_vnni_applicable(x, y, allow_padding=False):
     mcpu = tvm.target.Target.current().mcpu
     return (
         target_has_vnni(mcpu)
         and "int8" in x.dtype
         and "int8" in y.dtype
-        and y.shape[-2] % 16 == 0
-        and y.shape[-1] % 4 == 0
+        and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0))
     )
 
 
@@ -87,7 +86,10 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
 
 def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
     """Legalizes s8, s8 -> s32 GEMM op for VNNI."""
-    if check_vnni_applicable(arg_types[0], arg_types[1]) and arg_types[0].dtype == "int8":
+    if (
+        check_vnni_applicable(arg_types[0], arg_types[1], allow_padding=True)
+        and arg_types[0].dtype == "int8"
+    ):
         x, y = inputs
         x = relay.cast(x, "int32")
         x = relay.add(x, relay.const(128, "int32"))
@@ -98,7 +100,30 @@ def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
         if need_expand:
             adjust_shift = relay.expand_dims(adjust_shift, axis=1)
 
-        out = op(x, y, **attrs)
+        analyzer = tvm.arith.Analyzer()
+        x_shape = arg_types[0].shape
+        y_shape = arg_types[1].shape
+        inst_n = 16
+        inst_k = 4
+        pad_n = analyzer.simplify((inst_n - y_shape[-2] % inst_n) % inst_n)
+        pad_k = analyzer.simplify((inst_k - y_shape[-1] % inst_k) % inst_k)
+        if pad_k != 0 or pad_n != 0:
+            ndim = len(x_shape)
+            unpadded_dims = [(0, 0)] * (ndim - 2)
+            padding_y = [(0, 0)] * (len(y_shape) - 2) + [(0, pad_n), (0, pad_k)]
+            padded_y = relay.nn.pad(y, pad_width=padding_y, pad_value=0)
+            if pad_k != 0:
+                padding_x = [(0, 0)] * (len(x_shape) - 1) + [(0, pad_k)]
+                padded_x = relay.nn.pad(x, pad_width=padding_x, pad_value=0)
+            else:
+                padded_x = x
+            out = op(padded_x, padded_y, **attrs)
+            if pad_n != 0:
+                begin = [0] * len(x_shape)
+                end = x_shape[:-2] + [x_shape[-2], y_shape[-2]]
+                out = relay.strided_slice(out, begin, end, slice_mode="size")
+        else:
+            out = op(x, y, **attrs)
 
         return relay.subtract(out, adjust_shift)
 
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 7884fa35a4..1c93ee766a 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -753,9 +753,10 @@ def test_bitserial_dense():
 
 
 @tvm.testing.requires_cascadelake
-def test_dense_vnni():
-    data_shape = (32, 96)
-    weight_shape = (128, 96)
+@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)])
+def test_dense_vnni(m, n, k):
+    data_shape = (m, k)
+    weight_shape = (n, k)
 
     for data_dtype in ["uint8", "int8"]:
         data = relay.var("data", shape=data_shape, dtype=data_dtype)
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 5134ab156b..619a0b5a93 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -474,10 +474,18 @@ def test_batch_matmul(executor_kind):
 
 
 @tvm.testing.requires_cascadelake
-def test_batch_matmul_vnni():
-    x_shape = (16, 32, 96)
-    y_shape = (16, 128, 96)
-    z_shape = (16, 32, 128)
+@pytest.mark.parametrize(
+    "b,m,n,k",
+    [
+        (16, 32, 128, 96),
+        (16, 32, 128, 97),
+        (16, 32, 129, 96),
+    ],
+)
+def test_batch_matmul_vnni(b, m, n, k):
+    x_shape = (b, m, k)
+    y_shape = (b, n, k)
+    z_shape = (b, m, n)
 
     for lhs_dtype in ["uint8", "int8"]:
         x = relay.var("x", shape=x_shape, dtype=lhs_dtype)

Reply via email to