This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2d8ac1d [MKL] Fix offloading of batch_matmul to MKL (#6752)
2d8ac1d is described below
commit 2d8ac1db228929d75d06ee3663a239c611332dfd
Author: masahi <[email protected]>
AuthorDate: Sun Oct 25 17:12:27 2020 +0900
[MKL] Fix offloading of batch_matmul to MKL (#6752)
* fix mkl offloading of batch matmul
* name fix and add doc
* add doc for lib arg
Co-authored-by: masa <[email protected]>
---
python/tvm/relay/op/strategy/x86.py | 7 +++++++
python/tvm/topi/x86/batch_matmul.py | 30 +++++++++++++++++++++++++-----
2 files changed, 32 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index e2a82d3..3c5735b 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -377,6 +377,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
name="batch_matmul_cblas.x86",
plevel=15,
)
+ if "mkl" in target.libs:
+ strategy.add_implementation(
+ wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl),
+ wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl),
+ name="batch_matmul_mkl.x86",
+ plevel=15,
+ )
return strategy
diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
index 4e5f6ef..100bdf2 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -19,7 +19,7 @@
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
-from tvm.contrib import cblas
+from tvm.contrib import cblas, mkl
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
@@ -137,10 +137,9 @@ def _default_batch_matmul_config(cfg, M, N, K):
cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])
-@autotvm.register_topi_compute("batch_matmul_cblas.x86")
-def batch_matmul_cblas(cfg, x, y, out_shape=None):
+def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
- data in batch.
+ data in batch, using one of the supported BLAS libraries.
Parameters
----------
@@ -152,6 +151,8 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output
+ lib : A contrib module which implements a batch_matmul function.
+ cblas and mkl are supported.
Returns
-------
@@ -168,9 +169,28 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
cfg.add_flop(XB * M * N * XK * 2)
- return cblas.batch_matmul(x, y, False, True)
+ return lib.batch_matmul(x, y, False, True)
+
+
+@autotvm.register_topi_compute("batch_matmul_cblas.x86")
+def batch_matmul_cblas(cfg, x, y, out_shape=None):
+ """Compute batch_matmul using cblas"""
+ return batch_matmul_blas_common(cfg, x, y, out_shape, cblas)
@autotvm.register_topi_schedule("batch_matmul_cblas.x86")
def schedule_batch_matmul_cblas(_, outs):
+ """Create schedule for batch_matmul_cblas"""
+ return generic.schedule_extern(outs)
+
+
+@autotvm.register_topi_compute("batch_matmul_mkl.x86")
+def batch_matmul_mkl(cfg, x, y, out_shape=None):
+ """Compute batch_matmul using mkl"""
+ return batch_matmul_blas_common(cfg, x, y, out_shape, mkl)
+
+
+@autotvm.register_topi_schedule("batch_matmul_mkl.x86")
+def schedule_batch_matmul_mkl(_, outs):
+ """Create schedule for batch_matmul_mkl"""
return generic.schedule_extern(outs)