This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d8ac1d  [MKL] Fix offloading of batch_matmul to MKL (#6752)
2d8ac1d is described below

commit 2d8ac1db228929d75d06ee3663a239c611332dfd
Author: masahi <[email protected]>
AuthorDate: Sun Oct 25 17:12:27 2020 +0900

    [MKL] Fix offloading of batch_matmul to MKL (#6752)
    
    * fix mkl offloading of batch matmul
    
    * name fix and add doc
    
    * add doc for lib arg
    
    Co-authored-by: masa <[email protected]>
---
 python/tvm/relay/op/strategy/x86.py |  7 +++++++
 python/tvm/topi/x86/batch_matmul.py | 30 +++++++++++++++++++++++++-----
 2 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index e2a82d3..3c5735b 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -377,6 +377,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
             name="batch_matmul_cblas.x86",
             plevel=15,
         )
+    if "mkl" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl),
+            wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl),
+            name="batch_matmul_mkl.x86",
+            plevel=15,
+        )
     return strategy
 
 
diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
index 4e5f6ef..100bdf2 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -19,7 +19,7 @@
 from tvm import te
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity
-from tvm.contrib import cblas
+from tvm.contrib import cblas, mkl
 from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
@@ -137,10 +137,9 @@ def _default_batch_matmul_config(cfg, M, N, K):
     cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])
 
 
-@autotvm.register_topi_compute("batch_matmul_cblas.x86")
-def batch_matmul_cblas(cfg, x, y, out_shape=None):
+def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch.
+    data in batch, using one of the supported BLAS libraries.
 
     Parameters
     ----------
@@ -152,6 +151,8 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
         3-D with shape [batch, N, K]
     out_shape : tuple or None
         Shape of the output
+    lib : module
+        Contrib module that implements batch_matmul; cblas and mkl are supported.
 
     Returns
     -------
@@ -168,9 +169,28 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
         assert out_shape[1] == M, "got invalid output shape"
         assert out_shape[2] == N, "got invalid output shape"
     cfg.add_flop(XB * M * N * XK * 2)
-    return cblas.batch_matmul(x, y, False, True)
+    return lib.batch_matmul(x, y, False, True)
+
+
+@autotvm.register_topi_compute("batch_matmul_cblas.x86")
+def batch_matmul_cblas(cfg, x, y, out_shape=None):
+    """Compute batch_matmul using cblas"""
+    return batch_matmul_blas_common(cfg, x, y, out_shape, cblas)
 
 
 @autotvm.register_topi_schedule("batch_matmul_cblas.x86")
 def schedule_batch_matmul_cblas(_, outs):
+    """Create schedule for batch_matmul_cblas"""
+    return generic.schedule_extern(outs)
+
+
+@autotvm.register_topi_compute("batch_matmul_mkl.x86")
+def batch_matmul_mkl(cfg, x, y, out_shape=None):
+    """Compute batch_matmul using mkl"""
+    return batch_matmul_blas_common(cfg, x, y, out_shape, mkl)
+
+
+@autotvm.register_topi_schedule("batch_matmul_mkl.x86")
+def schedule_batch_matmul_mkl(_, outs):
+    """Create schedule for batch_matmul_mkl"""
     return generic.schedule_extern(outs)

Reply via email to