This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0009a30  [TOPI] VNNI support for int8 dense (#10230)
0009a30 is described below

commit 0009a308d82a321f2399923ab14b4c088461b4f2
Author: Masahiro Masuda <[email protected]>
AuthorDate: Tue Feb 15 12:44:02 2022 +0900

    [TOPI] VNNI support for int8 dense (#10230)
    
    * wip
    
    * revert for now
    
    * simplify blocking
    
    * add bench script
    
    * update type rel
    
    * refactor tests
    
    * end to end compilation working
    
    * parallelize outer loop
    
    * add shape check
    
    * fused schedule first cut
    
    * restore original test
    
    * black
    
    * add vnni check
    
    * add relay test
    
    * skip on ci
    
    * check dtype
    
    * lint
    
    * make it tunable
    
    * minor cleanup
---
 python/tvm/relay/op/strategy/x86.py          | 26 ++++++--
 python/tvm/topi/x86/dense.py                 | 94 ++++++++++++++++++++++++++++
 python/tvm/topi/x86/dense_alter_op.py        | 15 ++++-
 src/relay/op/nn/nn.cc                        |  2 +-
 tests/python/contrib/test_gemm_acc32_vnni.py |  4 +-
 tests/python/relay/test_op_level1.py         | 52 ++++++++++-----
 6 files changed, 169 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 20b6749..33e6862 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -530,11 +530,27 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
 def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
     """dense_pack x86 strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_dense(topi.x86.dense_pack),
-        wrap_topi_schedule(topi.x86.schedule_dense_pack),
-        name="dense_pack.x86",
-    )
+
+    if (
+        inputs[0].dtype == "uint8"
+        and inputs[1].dtype == "int8"
+        and out_type.dtype == "int32"
+        and attrs["weight_layout"] == "NC16n4c"
+    ):
+        strategy.add_implementation(
+            wrap_compute_dense(topi.x86.dense_vnni),
+            wrap_topi_schedule(topi.x86.schedule_dense_vnni),
+            name="dense_vnni.x86",
+            plevel=12,
+        )
+    else:
+        strategy.add_implementation(
+            wrap_compute_dense(topi.x86.dense_pack),
+            wrap_topi_schedule(topi.x86.schedule_dense_pack),
+            name="dense_pack.x86",
+            plevel=10,
+        )
+
     return strategy
 
 
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 9799ec0..fed5571 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -29,6 +29,7 @@ from tvm.contrib import mkldnn
 from .utils import get_simd_32bit_lanes
 from .. import generic, tag
 from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
 
 
 def _schedule_dense_pack_template(cfg, s, C, O):
@@ -279,6 +280,99 @@ def schedule_dense_pack(cfg, outs):
     return s
 
 
+def dense_vnni_compute(cfg, X, packed_w, bias=None):
+    """Compute for uint8 x int8 -> int32 dense"""
+    m, k = X.shape
+    n_o, _, n_i, _ = packed_w.shape
+    ak = te.reduce_axis((0, k), name="k")
+
+    C = te.compute(
+        (m, n_o * n_i),
+        lambda i, j: te.sum(
+            X[i, ak].astype("int32")
+            * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype(
+                "int32"
+            ),
+            axis=ak,
+        ),
+        tag="dense_vnni",
+    )
+
+    if bias is not None:
+        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)
+
+    a_y, _ = C.op.axis
+    cfg.define_split("tile_y", a_y, num_outputs=2)
+
+    return C
+
+
+def dense_vnni_schedule(cfg, s, C, O):
+    """Schedule dense compute using VNNI vpdpbusd instruction"""
+    # C: The output of GEMM
+    # O: The output of the fused op
+    def split_y(out):
+        default_y_split_factor = 32
+        a_y = out.op.axis[0]
+
+        if cfg.is_fallback:
+            return s[out].split(a_y, factor=default_y_split_factor)
+
+        return cfg["tile_y"].apply(s, out, a_y)
+
+    (a_k,) = C.op.reduce_axis
+
+    a_yo, a_yi = split_y(C)
+    a_xo, a_xi = s[C].split(C.op.axis[1], factor=16)
+    a_ko, a_ki = s[C].split(a_k, factor=4)
+
+    s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)
+
+    pc = dot_16x1x16_uint8_int8_int32_cascadelake()
+    s[C].tensorize(a_xi, pc)
+
+    if C == O:
+        fused = s[O].fuse(a_yo, a_xo)
+    else:
+        a_yo, a_yi = split_y(O)
+        a_xo, a_xi = s[O].split(O.op.axis[1], factor=16)
+
+        s[O].reorder(a_yo, a_xo, a_yi, a_xi)
+        s[O].vectorize(a_xi)
+        s[C].compute_at(s[O], a_yi)
+
+        fused = s[O].fuse(a_yo, a_xo)
+
+    s[O].parallel(fused)
+
+    return s
+
+
[email protected]_topi_compute("dense_vnni.x86")
+def dense_vnni(cfg, data, weight, bias=None, out_dtype=None):
+    """Compute for uint8 x int8 -> int32 dense"""
+    if out_dtype is None:
+        out_dtype = data.dtype
+    assert len(weight.shape) == 4
+    assert data.dtype == "uint8" and weight.dtype == "int8"
+    _, _, n_inner, k_inner = get_const_tuple(weight.shape)  # out_dim
+    assert n_inner == 16 and k_inner == 4
+    return dense_vnni_compute(cfg, data, weight, bias)
+
+
[email protected]_topi_schedule("dense_vnni.x86")
+def schedule_dense_vnni(cfg, outs):
+    """Create a schedule for dense_vnni"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if "dense_vnni" in op.tag:
+            dense_vnni_schedule(cfg, s, op.output(0), outs[0])
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
 def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
     """Compute matmul/dense using a BLAS library"""
     M, K = get_const_tuple(tensor_a.shape)
diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py
index 1d64261..273e5fa 100644
--- a/python/tvm/topi/x86/dense_alter_op.py
+++ b/python/tvm/topi/x86/dense_alter_op.py
@@ -24,6 +24,7 @@ from tvm import autotvm
 from .dense import _default_dense_pack_config
 from ..utils import get_const_tuple
 from ..nn import dense_alter_layout
+from .utils import target_has_vnni
 
 
 @dense_alter_layout.register(["cpu", "arm_cpu"])
@@ -34,8 +35,20 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
     out_dtype = out_type.dtype
     M, K = get_const_tuple(data_tensor.shape)
     N, _ = get_const_tuple(weight_tensor.shape)
+    mcpu = tvm.target.Target.current().mcpu
 
-    impl, outs = relay.backend.te_compiler.select_implementation(
+    if (
+        target_has_vnni(mcpu)
+        and data_tensor.dtype == "uint8"
+        and weight_tensor.dtype == "int8"
+        and weight_tensor.shape[0] % 16 == 0
+        and weight_tensor.shape[1] % 4 == 0
+    ):
+        # TODO(masahi): Support int8 x int8 case
+        weight_layout = "NC16n4c"
+        return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)
+
+    _, outs = relay.backend.te_compiler.select_implementation(
         relay.op.get("nn.dense"), attrs, tinfos, out_type, target
     )
     workload = autotvm.task.get_workload(outs)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index e0e0543..faa69d5 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -259,7 +259,7 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ICHECK(param != nullptr);
 
   ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
-  ICHECK_EQ(weight->shape.size(), 3) << "Weight is not packed";
+  ICHECK(weight->shape.size() == 3 || weight->shape.size() == 4) << "Expect weight to be 3D or 4D";
 
   Array<tvm::PrimExpr> oshape = data->shape;
   oshape.Set(1, weight->shape[0] * weight->shape[2]);
diff --git a/tests/python/contrib/test_gemm_acc32_vnni.py b/tests/python/contrib/test_gemm_acc32_vnni.py
index 3d03825..9cec823 100644
--- a/tests/python/contrib/test_gemm_acc32_vnni.py
+++ b/tests/python/contrib/test_gemm_acc32_vnni.py
@@ -57,7 +57,9 @@ def test_fc_int8_acc32():
             (m, n),
             lambda i, j: te.sum(
                 X[i, ak].astype("int32")
-                * packedW[j / 16, (ak / 4) * 16 + j % 16, ak % 4].astype("int32"),
+                * packedW[
+                    tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4) * 16 + j % 16, ak % 4
+                ].astype("int32"),
                 axis=ak,
             ),
             name="F",
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 791edcc..d505b7e 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -20,7 +20,7 @@ import tvm
 from tvm import te
 import scipy
 from tvm import relay
-from tvm.relay import transform
+import pytest
 from tvm.relay.testing import run_infer_type
 import tvm.topi.testing
 from tvm.contrib.nvcc import have_fp16
@@ -634,19 +634,39 @@ def test_bitserial_dense():
     assert yy.checked_type == relay.TensorType((m, 32), "int16")
 
 
[email protected]("Requires cascadelake")
+def test_dense_vnni():
+    data_shape = (32, 96)
+    weight_shape = (128, 96)
+
+    data = relay.var("data", shape=data_shape, dtype="uint8")
+    weight = relay.var("weight", shape=weight_shape, dtype="int8")
+    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+    out = relay.nn.bias_add(dense, bias)
+    mod = tvm.IRModule.from_expr(out)
+
+    target = "llvm -mcpu=cascadelake"
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relay.build(mod, target=target)
+
+    dev = tvm.device(target, 0)
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    a = np.random.uniform(1, 10, size=data_shape).astype("uint8")
+    b = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+    c = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+    runtime.set_input("data", a)
+    runtime.set_input("weight", b)
+    runtime.set_input("bias", c)
+    runtime.run()
+
+    out = runtime.get_output(0).numpy()
+    ref = np.dot(a, b.transpose()) + c
+
+    np.testing.assert_equal(out, ref)
+
+
 if __name__ == "__main__":
-    test_concatenate()
-    test_bias_add()
-    test_bias_add_type_failure()
-    test_unary_op()
-    test_binary_op()
-    test_expand_dims_infer_type()
-    test_expand_dims()
-    test_softmax()
-    test_log_softmax()
-    test_dropout()
-    test_batch_norm()
-    test_matmul()
-    test_dense()
-    test_bitserial_dense()
-    test_dense_dtype()
+    pytest.main([__file__])

Reply via email to