This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d1ced0  [Matmul] Add matmul op (#8234)
6d1ced0 is described below

commit 6d1ced080390d221d1d64f37038cc869e0127f45
Author: Chenfan <jc...@outlook.com>
AuthorDate: Wed Jun 30 22:29:43 2021 +0800

    [Matmul] Add matmul op (#8234)
    
    * Add Matmul Op
    
    * Recover DenseAttrs
    
    * Add grad for matmul & some update
    
    * Update matmul cuda default schedule
    
    * Add blas support for matmul
    
    * Lint fixes and doc string updates
---
 include/tvm/relay/attrs/nn.h                       |  26 ++++
 python/tvm/relay/frontend/tensorflow.py            |  19 ++-
 python/tvm/relay/frontend/tensorflow_ops.py        |  20 ++-
 python/tvm/relay/op/_tensor_grad.py                |  29 +++++
 python/tvm/relay/op/nn/_nn.py                      |  63 ++++++++--
 python/tvm/relay/op/nn/nn.py                       |  44 +++++++
 python/tvm/relay/op/op_attrs.py                    |   5 +
 python/tvm/relay/op/strategy/cuda.py               |  32 +++++
 python/tvm/relay/op/strategy/generic.py            |  36 ++++++
 python/tvm/relay/op/strategy/x86.py                |  72 +++++++++++
 python/tvm/topi/cuda/dense.py                      |  50 ++++++--
 python/tvm/topi/generic/nn.py                      |  17 +++
 python/tvm/topi/gpu/dense.py                       |  30 +++++
 python/tvm/topi/nn/dense.py                        | 138 ++++++++++++++++++---
 python/tvm/topi/x86/dense.py                       | 119 ++++++++++++------
 rust/tvm/src/ir/relay/attrs/nn.rs                  |  12 ++
 src/relay/op/make_op.h                             |   3 +
 src/relay/op/nn/nn.cc                              |  40 +++++-
 src/relay/op/nn/nn.h                               |  72 ++++++-----
 src/relay/qnn/op/dense.cc                          |   2 +-
 .../transforms/auto_scheduler_layout_rewrite.cc    |  10 +-
 tests/python/frontend/tensorflow/test_forward.py   |  12 +-
 tests/python/relay/test_op_grad_level2.py          |  17 +++
 tests/python/relay/test_op_level1.py               |  63 +++++++++-
 tests/python/topi/python/test_topi_matmul.py       |  26 ++++
 25 files changed, 842 insertions(+), 115 deletions(-)
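
For reference, a minimal sketch of how the new Relay-level API introduced by this patch can be exercised (the shapes and the llvm target below are illustrative assumptions, not part of the commit):

    import tvm
    from tvm import relay

    # tensor_a is (K, M) because transpose_a=True; tensor_b is (K, N).
    a = relay.var("a", shape=(16, 32), dtype="float32")
    b = relay.var("b", shape=(16, 8), dtype="float32")
    # nn.matmul is the op added by this commit; the transpose flags avoid
    # inserting explicit transpose ops around nn.dense.
    y = relay.nn.matmul(a, b, transpose_a=True, transpose_b=False)
    mod = tvm.IRModule.from_expr(relay.Function([a, b], y))

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")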

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index dc20267..3c75745 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -961,6 +961,32 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
   }
 };
 
+/*! \brief Attributes for matmul operator */
+struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
+  IndexExpr units;
+  DataType out_dtype;
+  bool transpose_a;
+  bool transpose_b;
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+
+  TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") {
+    TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
+
+    // use 0 bits to indicate none.
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+
+    TVM_ATTR_FIELD(transpose_a)
+        .set_default(false)
+        .describe("Whether the first input tensor is in transposed format.");
+
+    TVM_ATTR_FIELD(transpose_b)
+        .set_default(false)
+        .describe("Whether the second input tensor is in transposed format.");
+  }
+};
+
 /*! \brief Attributes for dense operator */
 struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   IndexExpr units;
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 0bdec95..e297398 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -44,6 +44,16 @@ from .tensorflow_ops import _get_more_static_shape
 
 __all__ = ["from_tensorflow"]
 
+# The default configurations of Relay TensorFlow frontend.
+TF_DEFAULT_CONFIGS = {
+    # By default, TVM converts `tf.matmul` to `transpose(weight) + nn.dense`, which introduces
+    # unnecessary overhead in the weight transpose. Change this flag to False to convert directly to
+    # `nn.matmul` and get rid of the overhead.
+    # However, please note that `nn.matmul` is experimental, so it may have some performance
+    # issues.
+    "use_dense": True,
+}
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1204,7 +1214,7 @@ class SubGraphProto(GraphProto):
         return func, self._params
 
 
-def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
+def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True):
     """Load tensorflow graph which is a python tensorflow graph object into relay.
     The companion parameters will be handled automatically.
 
@@ -1222,6 +1232,11 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     outputs : List of output tensor names (Optional)
         if not specified then the last node is assumed as graph output.
 
+    use_dense_op : bool (Optional) = True
+        True to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
+        The `nn.dense` op requires the data tensor to be non-transposed and the weight tensor to be
+        transposed, so extra `transpose` ops may be inserted into the original graph.
+
     Returns
     -------
     mod : tvm.IRModule
@@ -1230,6 +1245,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     params : dict of str to tvm.nd.NDArray
         Dict of converted parameters stored in tvm.nd.NDArray format
     """
+    global TF_DEFAULT_CONFIGS
+    TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op
 
     g = GraphProto()
     mod, params = g.from_tensorflow(graph, layout, shape, outputs)
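
A sketch of how the new `use_dense_op` switch might be used when importing a TensorFlow 1.x graph; the toy graph below is an assumption for illustration only:

    import numpy as np
    import tensorflow.compat.v1 as tf
    from tvm import relay

    tf.disable_eager_execution()
    with tf.Graph().as_default() as graph:
        a = tf.placeholder(tf.float32, shape=(4, 8), name="a")
        w = tf.constant(np.random.rand(8, 16).astype("float32"))
        tf.matmul(a, w, name="out")
        graph_def = graph.as_graph_def()

    # use_dense_op=False makes the importer emit nn.matmul directly instead of
    # transpose(weight) + nn.dense (the default behaviour is unchanged).
    mod, params = relay.frontend.from_tensorflow(
        graph_def, shape={"a": (4, 8)}, outputs=["out"], use_dense_op=False
    )
    print(mod["main"])
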
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index be15f83..004174f 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1113,13 +1113,23 @@ def _no_op():
 
 def _matmul():
     def _impl(inputs, attr, params, mod):
+        from .tensorflow import TF_DEFAULT_CONFIGS
+
         channels = _infer_channels(inputs[1], not attr["transpose_b"])
-        if attr["transpose_a"]:
-            inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
-        if not attr["transpose_b"]:
-            inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
+        if TF_DEFAULT_CONFIGS["use_dense"]:
+            if attr["transpose_a"]:
+                inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
+            if not attr["transpose_b"]:
+                inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
+            return AttrCvt(
+                op_name="dense",
+                extras={"units": channels},
+                ignores=["transpose_a", "transpose_b", "T"],
+            )(inputs, attr)
         return AttrCvt(
-            op_name="dense", extras={"units": channels}, ignores=["transpose_a", "transpose_b", "T"]
+            op_name="matmul",
+            extras={"units": channels},
+            ignores=["T"],
         )(inputs, attr)
 
     return _impl
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 09b1435..fa2772c 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -554,6 +554,35 @@ def dense_grad(orig, grad):
     ]
 
 
+@register_gradient("nn.matmul")
+def matmul_grad(orig, grad):
+    """Returns [grad' @ tensor_b, tensor_a @ grad']"""
+    tensor_a, tensor_b = orig.args
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
+        return [
+            collapse_sum_like(
+                _nn.matmul(tensor_b, grad, transpose_a=True, transpose_b=True), tensor_a
+            ),
+            collapse_sum_like(
+                _nn.matmul(grad, tensor_a, transpose_a=True, transpose_b=True), tensor_b
+            ),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
+        return [
+            collapse_sum_like(_nn.matmul(tensor_b, grad, transpose_b=True), tensor_a),
+            collapse_sum_like(_nn.matmul(tensor_a, grad), tensor_b),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
+        # Keep using Dense op here for not involving extra ops
+        # TODO(jcf94): Merge all to nn.matmul when it is finally ready
+        return dense_grad(orig, grad)
+    # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
+    return [
+        collapse_sum_like(_nn.matmul(grad, tensor_b, transpose_b=True), tensor_a),
+        collapse_sum_like(_nn.matmul(tensor_a, grad, transpose_a=True), tensor_b),
+    ]
+
+
 @register_gradient("nn.batch_matmul")
 def batch_matmul_grad(orig, grad):
     """gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 04d38ce..056cb56 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -52,6 +52,32 @@ reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
+@reg.register_legalize("nn.matmul")
+def legalize_matmul(attrs, inputs, types):
+    """Legalize matmul op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current matmul
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.matmul_legalize(attrs, inputs, types)
+
+
+# matmul
+reg.register_strategy("nn.matmul", strategy.matmul_strategy)
+reg.register_pattern("nn.matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 @reg.register_legalize("nn.dense")
 def legalize_dense(attrs, inputs, types):
     """Legalize dense op.
@@ -1160,21 +1186,44 @@ def batch_flatten_shape_func(attrs, inputs, _):
 
 
 @script
-def _dense_shape_func(data_shape, weight_shape):
-    out = output_tensor((data_shape.shape[0],), "int64")
+def _matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
+    out = output_tensor((tensor_a_shape.shape[0],), "int64")
     for i in const_range(out.shape[0] - 1):
-        out[i] = data_shape[i]
-    out[out.shape[0] - 1] = weight_shape[0]
+        out[i] = tensor_a_shape[i]
+    if transpose_a:
+        out[out.shape[0] - 2] = out[out.shape[0] - 1]
+    out[out.shape[0] - 1] = tensor_b_shape[0] if transpose_b else tensor_b_shape[1]
 
     return out
 
 
+@reg.register_shape_func("nn.matmul", False)
+def matmul_shape_func(attrs, inputs, _):
+    """Shape function for matmul op."""
+    ret = [
+        _matmul_shape_func(
+            inputs[0],
+            inputs[1],
+            expr.IntImm("bool", attrs.transpose_a),
+            expr.IntImm("bool", attrs.transpose_b),
+        )
+    ]
+    return ret
+
+
 @reg.register_shape_func("nn.dense", False)
 def dense_shape_func(attrs, inputs, _):
+    """Shape function for dense op. This is an alias of matmul_nt operator for data tensor in
+    non-transposed format and weight tensor in transposed format.
     """
-    Shape function for dense op.
-    """
-    ret = [_dense_shape_func(inputs[0], inputs[1])]
+    ret = [
+        _matmul_shape_func(
+            inputs[0],
+            inputs[1],
+            expr.IntImm("bool", False),
+            expr.IntImm("bool", True),
+        )
+    ]
     return ret
 
 
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index bef899e..4c94102 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1471,6 +1471,50 @@ def bias_add(data, bias, axis=1):
     return _make.bias_add(data, bias, axis)
 
 
+def matmul(tensor_a, tensor_b, units=None, out_dtype="", transpose_a=False, transpose_b=False):
+    """Matmul operator.
+    Applies a linear transformation. Both A and B can be transposed.
+
+    .. math::
+
+        `C = A * B`
+
+    Parameters
+    ----------
+    tensor_a : tvm.relay.Expr
+        The first input of the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
+
+    tensor_b : tvm.relay.Expr
+        The second input expression, a 2-D matrix,
+        of shape `(units_in, units)` or `(units, units_in)`.
+
+    units : Optional[int]
+        Number of hidden units of the matmul transformation.
+
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision matmul;
+        the result will be of shape `(d_1, d_2, ..., d_n, units)`.
+
+    transpose_a : Optional[bool] = False
+        Whether the data tensor is in transposed format.
+
+    transpose_b : Optional[bool] = False
+        Whether the weight tensor is in transposed format.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    # Since `nn.dense` currently has better topi schedule support, prefer `dense`
+    # rather than `matmul` for better compatibility.
+    if not transpose_a and transpose_b:
+        # TODO(jcf94): Remove this when `nn.matmul` is finally ready
+        return dense(tensor_a, tensor_b, units, out_dtype)
+    return _make.matmul(tensor_a, tensor_b, units, out_dtype, transpose_a, transpose_b)
+
+
 def dense(data, weight, units=None, out_dtype=""):
     """Dense operator.
     Applies a linear transformation
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 74c4e2f..780badc 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -64,6 +64,11 @@ class BiasAddAttrs(Attrs):
     """Atttribute of nn.bias_add"""
 
 
+@tvm._ffi.register_object("relay.attrs.MatmulAttrs")
+class MatmulAttrs(Attrs):
+    """Attributes for nn.matmul"""
+
+
 @tvm._ffi.register_object("relay.attrs.DenseAttrs")
 class DenseAttrs(Attrs):
     """Attributes for nn.dense"""
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 683f3ec..dd265e4 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -698,6 +698,38 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@matmul_strategy.register(["cuda", "gpu"])
+def matmul_strategy_cuda(attrs, inputs, out_type, target):
+    """Matmul cuda strategy."""
+    strategy = _op.OpStrategy()
+
+    if is_auto_scheduler_enabled():
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.nn.matmul),
+            naive_schedule,
+            name="matmul.cuda",
+        )
+    else:
+        logger.warning(
+            "Matmul is not optimized for cuda. Recommend to use cublas for better performance."
+        )
+        # Temporary use this as a basic schedule
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.gpu.matmul_default),
+            wrap_topi_schedule(topi.gpu.schedule_matmul_default),
+            name="matmul_default.gpu",
+        )
+
+    if target.kind.name == "cuda" and "cublas" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.cuda.matmul_cublas),
+            wrap_topi_schedule(topi.cuda.schedule_matmul_cublas),
+            name="matmul_cublas.cuda",
+            plevel=25,
+        )
+    return strategy
+
+
 @dense_strategy.register(["cuda", "gpu"])
 def dense_strategy_cuda(attrs, inputs, out_type, target):
     """dense cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 35e5177..5cb3f65 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -712,6 +712,42 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+# matmul
+def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
+    """wrap matmul topi compute"""
+
+    def _compute_matmul(attrs, inputs, out_type):
+        """Compute definition of matmul"""
+        out_dtype = attrs.out_dtype
+        out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
+        args = [
+            inputs[0],
+            inputs[1],
+            None,
+            out_dtype,
+            attrs.transpose_a,
+            attrs.transpose_b,
+        ]
+        if need_auto_scheduler_layout:
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
+        return [topi_compute(*args)]
+
+    return _compute_matmul
+
+
+@override_native_generic_func("matmul_strategy")
+def matmul_strategy(attrs, inputs, out_type, target):
+    """matmul generic strategy"""
+    logger.warning("matmul is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_matmul(topi.nn.matmul),
+        wrap_topi_schedule(topi.generic.schedule_matmul),
+        name="matmul.generic",
+    )
+    return strategy
+
+
 # dense
 def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
     """wrap dense topi compute"""
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index c21ec4d..d09d90a 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -370,6 +370,78 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):
     return strategy
 
 
+@matmul_strategy.register("cpu")
+def matmul_strategy_cpu(attrs, inputs, out_type, target):
+    """matmul x86 strategy"""
+    strategy = _op.OpStrategy()
+
+    same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
+    dtype = inputs[0].dtype
+    u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32"
+    if "cblas" in target.libs:
+        length_before = len(strategy.specializations) if strategy.specializations else 0
+        with SpecializedCondition(same_type and dtype in ["float32", "float64"]):
+            strategy.add_implementation(
+                wrap_compute_matmul(topi.x86.matmul_cblas),
+                wrap_topi_schedule(topi.x86.schedule_matmul_cblas),
+                name="matmul_cblas.x86",
+                plevel=13,
+            )
+        length_after = len(strategy.specializations) if strategy.specializations else 0
+        if length_before == length_after:
+            logger.warning(
+                "Currently cblas only supports the data type to be float32 or float64. Skip."
+            )
+    if "mkl" in target.libs:
+        length_before = len(strategy.specializations) if strategy.specializations else 0
+        with SpecializedCondition(same_type and dtype in ["float32", "float64"] or u8s8s32):
+            strategy.add_implementation(
+                wrap_compute_matmul(topi.x86.matmul_mkl),
+                wrap_topi_schedule(topi.x86.schedule_matmul_mkl),
+                name="matmul_mkl.x86",
+                plevel=14,
+            )
+        length_after = len(strategy.specializations) if strategy.specializations else 0
+        if length_before == length_after:
+            logger.warning(
+                "Currently mkl only supports the data type to be float32, float64 or input with "
+                "uint8 and int8 while output with int32. Skip."
+            )
+    if "mkldnn" in target.libs:
+        length_before = len(strategy.specializations) if strategy.specializations else 0
+        with SpecializedCondition(same_type and dtype == "float32"):
+            strategy.add_implementation(
+                wrap_compute_matmul(topi.x86.matmul_mkldnn),
+                wrap_topi_schedule(topi.x86.schedule_matmul_mkldnn),
+                name="matmul_mkldnn.x86",
+                plevel=15,
+            )
+        length_after = len(strategy.specializations) if strategy.specializations else 0
+        if length_before == length_after:
+            logger.warning("Currently mkldnn only supports the data type to be float32. Skip.")
+
+    if is_auto_scheduler_enabled():
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.nn.matmul, need_auto_scheduler_layout=True),
+            naive_schedule,
+            name="matmul.generic",
+            plevel=11,
+        )
+    else:
+        # If no cblas/mkl/mkldnn strategy was chosen
+        if not strategy.specializations:
+            logger.warning(
+                "Matmul is not optimized for x86. "
+                "Recommend to use cblas/mkl/mkldnn for better performance."
+            )
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.nn.matmul),
+            naive_schedule,
+            name="matmul.generic",
+        )
+    return strategy
+
+
 @dense_strategy.register("cpu")
 def dense_strategy_cpu(attrs, inputs, out_type, target):
     """dense x86 strategy"""
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 0f410ae..4035dce 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -28,18 +28,24 @@ from ..utils import traverse_inline, get_const_tuple
 logger = logging.getLogger("topi")
 
 
-@autotvm.register_topi_compute("dense_cublas.cuda")
-def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator on CUDA with CUBLAS"""
-    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
+def _matmul_cublas_common(
+    cfg,
+    tensor_a,
+    tensor_b,
+    bias=None,
+    out_dtype=None,
+    transpose_a=False,
+    transpose_b=False,
+):
+    assert len(tensor_a.shape) == 2 and len(tensor_b.shape) == 2, "only support 2-dim matmul"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
-        out_dtype = data.dtype
-    assert out_dtype == data.dtype, "Mixed precision not supported."
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    matmul = cublas.matmul(data, weight, False, True)
+        out_dtype = tensor_a.dtype
+    assert out_dtype == tensor_a.dtype, "Mixed precision not supported."
+    batch, in_dim = get_const_tuple(tensor_a.shape)
+    out_dim, _ = get_const_tuple(tensor_b.shape)
+    matmul = cublas.matmul(tensor_a, tensor_b, transpose_a, transpose_b)
     if all(isinstance(d, int) for d in [batch, in_dim, out_dim]):
         cfg.add_flop(batch * in_dim * out_dim * 2)
     if bias is not None:
@@ -49,6 +55,32 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
     return matmul
 
 
+@autotvm.register_topi_compute("matmul_cublas.cuda")
+def matmul_cublas(
+    cfg,
+    tensor_a,
+    tensor_b,
+    bias=None,
+    out_dtype=None,
+    transpose_a=False,
+    transpose_b=False,
+):
+    """Matmul operator on CUDA with CUBLAS"""
+    return _matmul_cublas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b)
+
+
+@autotvm.register_topi_schedule("matmul_cublas.cuda")
+def schedule_matmul_cublas(_, outs):
+    """Schedule matmul operator using CUBLAS"""
+    return generic.schedule_extern(outs)
+
+
+@autotvm.register_topi_compute("dense_cublas.cuda")
+def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
+    """Dense operator on CUDA with CUBLAS. This is an alias of matmul_nt operator."""
+    return _matmul_cublas_common(cfg, data, weight, bias, out_dtype, False, True)
+
+
 @autotvm.register_topi_schedule("dense_cublas.cuda")
 def schedule_dense_cublas(_, outs):
     """Schedule dense operator using CUBLAS"""
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 04d6490..1b32141 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -580,6 +580,23 @@ def schedule_fast_softmax(outs):
     return _default_schedule(outs, False)
 
 
+def schedule_matmul(outs):
+    """Schedule for matmul
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of matmul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_dense(outs):
     """Schedule for dense
 
diff --git a/python/tvm/topi/gpu/dense.py b/python/tvm/topi/gpu/dense.py
index 806aa9f..b9009d3 100644
--- a/python/tvm/topi/gpu/dense.py
+++ b/python/tvm/topi/gpu/dense.py
@@ -49,6 +49,36 @@ def schedule_dense_small_batch(cfg, outs):
     return s
 
 
+@autotvm.register_topi_compute("matmul_default.gpu")
+def matmul_default(
+    cfg,
+    tensor_a,
+    tensor_b,
+    bias=None,
+    out_dtype=None,
+    transpose_a=False,
+    transpose_b=False,
+):
+    """Matmul operator on GPU"""
+    return nn.matmul(tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b)
+
+
+@autotvm.register_topi_schedule("matmul_default.gpu")
+def schedule_matmul_default(cfg, outs):
+    """Schedule matmul on GPU"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "matmul":
+            # Temporary use this as a basic schedule for matmul
+            # TODO(jcf94): Add a more general schedule for matmul
+            _schedule_dense_small_batch(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
 def _schedule_dense_small_batch(cfg, s, C):
     A, weights = C.op.input_tensors
     _, in_dim_weights = get_const_tuple(weights.shape)
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index e8ec476..58c458a 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -21,15 +21,23 @@ from tvm import te, auto_scheduler
 from .. import tag
 
 
-def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
-    """The default implementation of dense in topi.
+def matmul(
+    tensor_a,
+    tensor_b,
+    bias=None,
+    out_dtype=None,
+    transpose_a=False,
+    transpose_b=False,
+    auto_scheduler_rewritten_layout="",
+):
+    """The default implementation of matmul in topi.
 
     Parameters
     ----------
-    data : tvm.te.Tensor
+    tensor_a : tvm.te.Tensor
         2-D with shape [batch, in_dim]
 
-    weight : tvm.te.Tensor
+    tensor_b : tvm.te.Tensor
         2-D with shape [out_dim, in_dim]
 
     bias : Optional[tvm.te.Tensor]
@@ -38,7 +46,13 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
     out_dtype : Optional[str]
         The output type. This is used for mixed precision.
 
-    auto_scheduler_rewritten_layout: str = ""
+    transpose_a : Optional[bool] = False
+        Whether the tensor_a is in transposed format.
+
+    transpose_b : Optional[bool] = False
+        Whether the tensor_b is in transposed format.
+
+    auto_scheduler_rewritten_layout: Optional[str] = ""
         The layout after auto-scheduler's layout rewrite pass.
 
     Returns
@@ -46,42 +60,128 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    assert len(data.shape) == 2, "only support 2-dim dense"
+    # TODO(jcf94): Add multi-dim support for tensor_a
+    assert len(tensor_a.shape) == 2, "only support 2-dim matmul"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = data.shape
+        out_dtype = tensor_a.dtype
+    if transpose_a:
+        in_dim, batch = tensor_a.shape
+    else:
+        batch, in_dim = tensor_a.shape
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
         out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
             auto_scheduler_rewritten_layout, ["j", "k"]
         )
-        auto_scheduler.remove_index_check(weight)
+        auto_scheduler.remove_index_check(tensor_b)
+    elif transpose_b:
+        out_dim, red_dim = tensor_b.shape
     else:
-        out_dim, red_dim = weight.shape
+        red_dim, out_dim = tensor_b.shape
     assert in_dim == red_dim
 
     k = te.reduce_axis((0, in_dim), name="k")
-    matmul = te.compute(
+    if (transpose_a, transpose_b) == (True, True):
+        compute_lambda = lambda i, j: te.sum(
+            tensor_a[k, i].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
+        )
+        compute_name = "T_matmul_TT"
+        compute_tag = "matmul"
+    elif (transpose_a, transpose_b) == (True, False):
+        compute_lambda = lambda i, j: te.sum(
+            tensor_a[k, i].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
+        )
+        compute_name = "T_matmul_TN"
+        compute_tag = "matmul"
+    elif (transpose_a, transpose_b) == (False, True):
+        compute_lambda = lambda i, j: te.sum(
+            tensor_a[i, k].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
+        )
+        compute_name = "T_matmul_NT"
+        # TODO(jcf94): Remove `dense` when `matmul` is finally ready
+        compute_tag = "dense"
+    else:  # (transpose_a, transpose_b) == (False, False):
+        compute_lambda = lambda i, j: te.sum(
+            tensor_a[i, k].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
+        )
+        compute_name = "T_matmul_NN"
+        compute_tag = "matmul"
+
+    mat = te.compute(
         (batch, out_dim),
-        lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
-        name="T_dense",
-        tag="dense",
-        attrs={"layout_free_placeholders": [weight]},
+        compute_lambda,
+        name=compute_name,
+        tag=compute_tag,
+        attrs={"layout_free_placeholders": [tensor_b]},
     )
+
     if bias is not None:
-        matmul = te.compute(
+        mat = te.compute(
             (batch, out_dim),
-            lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
+            lambda i, j: mat[i, j] + bias[j].astype(out_dtype),
             tag=tag.BROADCAST,
         )
 
     if auto_scheduler_rewritten_layout:
-        matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)
+        mat = auto_scheduler.rewrite_compute_body(mat, auto_scheduler_rewritten_layout)
+
+    return mat
+
+
+@tvm.target.generic_func
+def matmul_legalize(attrs, inputs, types):
+    """Legalizes matmul op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current matmul
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    # not to change by default
+    # pylint: disable=unused-argument
+    return None
+
+
+def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
+    """The default implementation of dense in topi.
+    This is an alias of matmul_nt operator for data tensor in non-transposed format and weight
+    tensor in transposed format.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        2-D with shape [batch, in_dim]
 
-    return matmul
+    weight : tvm.te.Tensor
+        2-D with shape [out_dim, in_dim]
+
+    bias : Optional[tvm.te.Tensor]
+        1-D with shape [out_dim]
+
+    out_dtype : Optional[str]
+        The output type. This is used for mixed precision.
+
+    auto_scheduler_rewritten_layout: str = ""
+        The layout after auto-scheduler's layout rewrite pass.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        2-D with shape [batch, out_dim]
+    """
+    return matmul(data, weight, bias, out_dtype, False, True, auto_scheduler_rewritten_layout)
 
 
 @tvm.target.generic_func
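
At the TE/TOPI level the new compute can also be used directly, mirroring the test helper added in test_topi_matmul.py; a rough sketch with a naive schedule on llvm (shapes are assumptions):

    import numpy as np
    import tvm
    from tvm import te, topi

    # transpose_a matmul: A is (K, M), B is (K, N), C is (M, N).
    A = te.placeholder((16, 8), name="A", dtype="float32")
    B = te.placeholder((16, 4), name="B", dtype="float32")
    C = topi.nn.matmul(A, B, transpose_a=True, transpose_b=False)

    s = te.create_schedule(C.op)
    f = tvm.build(s, [A, B, C], target="llvm")

    a = tvm.nd.array(np.random.rand(16, 8).astype("float32"))
    b = tvm.nd.array(np.random.rand(16, 4).astype("float32"))
    c = tvm.nd.array(np.zeros((8, 4), dtype="float32"))
    f(a, b, c)
    np.testing.assert_allclose(c.numpy(), a.numpy().T @ b.numpy(), rtol=1e-5)
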
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 4fed4c1..189ac5b 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -28,7 +28,7 @@ from tvm.contrib import mkldnn
 
 from .utils import get_fp32_len
 from .injective import schedule_injective_from_existing
-from .. import generic, tag
+from .. import tag
 from ..utils import traverse_inline, get_const_tuple
 
 
@@ -281,72 +281,121 @@ def schedule_dense_pack(cfg, outs):
     return s
 
 
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
-    """Compute dense using a BLAS library"""
-    M, K = get_const_tuple(data.shape)
-    N, _ = get_const_tuple(weight.shape)
+def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
+    """Compute matmul/dense using a BLAS library"""
+    M, K = get_const_tuple(tensor_a.shape)
+    N, _ = get_const_tuple(tensor_b.shape)
     if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
         cfg.add_flop(M * K * N * 2)
-    if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype == "int32":
+    if tensor_a.dtype == "uint8" and tensor_b.dtype == "int8" and out_dtype == "int32":
         if not hasattr(lib, "matmul_u8s8s32"):
             raise NotImplementedError(
-                f"Dense with {lib.__name__} for {data.dtype} is not supported "
+                f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not supported "
                 "(matmulu8s8s32 not imlemented)"
             )
-        C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
-    elif data.dtype == "float32" or data.dtype == "float64":
-        C = lib.matmul(data, weight, False, True)
+        C = lib.matmul_u8s8s32(tensor_a, tensor_b, transpose_a, transpose_b, dtype=out_dtype)
+    elif tensor_a.dtype == "float32" or tensor_a.dtype == "float64":
+        C = lib.matmul(tensor_a, tensor_b, transpose_a, transpose_b)
     else:
-        raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype} is not supported")
+        raise NotImplementedError(
+            f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not supported"
+        )
 
     if bias is not None:
         C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
     return C
 
 
+def schedule_matmul_blas_common(outs):
+    """Default matmul schedule for BLAS library"""
+    s = te.create_schedule([x.op for x in outs])
+    te.schedule.AutoInlineInjective(s)
+
+    for out in outs:
+        if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+            schedule_injective_from_existing(s, out)
+    return s
+
+
 @autotvm.register_topi_compute("dense_cblas.x86")
 def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
-    """Compute dense using a cblas"""
-    return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
+    """Compute dense using cblas. This is an alias of matmul_nt operator."""
+    return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, cblas)
 
 
 @autotvm.register_topi_schedule("dense_cblas.x86")
 def schedule_dense_cblas(_, outs):
-    """Create schedule for dense_cblas"""
-    return generic.schedule_extern(outs)
+    """Create schedule for dense_cblas. This is an alias of matmul_nt operator."""
+    return schedule_matmul_blas_common(outs)
 
 
 @autotvm.register_topi_compute("dense_mkl.x86")
 def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
-    """Compute dense using mkl"""
-    return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl)
+    """Compute dense using mkl. This is an alias of matmul_nt operator."""
+    return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkl)
 
 
 @autotvm.register_topi_schedule("dense_mkl.x86")
 def schedule_dense_mkl(_, outs):
-    """Create schedule for dense_mkl"""
-    # return generic.schedule_extern(outs)
-    s = te.create_schedule([x.op for x in outs])
-    te.schedule.AutoInlineInjective(s)
-
-    def _callback(op):
-        if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in op.tag:
-            schedule_injective_from_existing(s, op.output(0))
-
-    # traverse_inline(s, outs[0].op, _callback)
-    for out in outs:
-        if "dense" not in out.op.name:
-            schedule_injective_from_existing(s, out)
-    return s
+    """Create schedule for dense_mkl. This is an alias of matmul_nt operator."""
+    return schedule_matmul_blas_common(outs)
 
 
 @autotvm.register_topi_compute("dense_mkldnn.x86")
 def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None):
-    """Compute dense using mkldnn"""
-    return dense_blas_common(cfg, data, weight, bias, out_dtype, mkldnn)
+    """Compute dense using mkldnn. This is an alias of matmul_nt operator."""
+    return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkldnn)
 
 
 @autotvm.register_topi_schedule("dense_mkldnn.x86")
 def schedule_dense_mkldnn(_, outs):
-    """Create schedule for dense_mkldnn"""
-    return generic.schedule_extern(outs)
+    """Create schedule for dense_mkldnn. This is an alias of matmul_nt operator."""
+    return schedule_matmul_blas_common(outs)
+
+
+@autotvm.register_topi_compute("matmul_cblas.x86")
+def matmul_cblas(
+    cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False
+):
+    """Compute matmul using cblas."""
+    return matmul_blas_common(
+        cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, cblas
+    )
+
+
+@autotvm.register_topi_schedule("matmul_cblas.x86")
+def schedule_matmul_cblas(_, outs):
+    """Create schedule for matmul_cblas."""
+    return schedule_matmul_blas_common(outs)
+
+
+@autotvm.register_topi_compute("matmul_mkl.x86")
+def matmul_mkl(
+    cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False
+):
+    """Compute matmul using mkl."""
+    return matmul_blas_common(
+        cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, mkl
+    )
+
+
+@autotvm.register_topi_schedule("matmul_mkl.x86")
+def schedule_matmul_mkl(_, outs):
+    """Create schedule for matmul_mkl."""
+    return schedule_matmul_blas_common(outs)
+
+
+@autotvm.register_topi_compute("matmul_mkldnn.x86")
+def matmul_mkldnn(
+    cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False
+):
+    """Compute matmul using mkldnn."""
+    return matmul_blas_common(
+        cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, mkldnn
+    )
+
+
+@autotvm.register_topi_schedule("matmul_mkldnn.x86")
+def schedule_matmul_mkldnn(_, outs):
+    """Create schedule for matmul_mkldnn."""
+    return schedule_matmul_blas_common(outs)
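
A sketch of how the cblas-backed matmul compute added above can be selected at build time; it assumes TVM was built with BLAS support (USE_BLAS), otherwise drop the -libs flag and the generic schedule is used instead:

    import tvm
    from tvm import relay

    a = relay.var("a", shape=(32, 64), dtype="float32")
    b = relay.var("b", shape=(64, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([a, b], relay.nn.matmul(a, b)))

    # With -libs=cblas the matmul_cblas.x86 implementation registered in the
    # x86 strategy is preferred (plevel=13) for float32/float64 inputs.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm -libs=cblas")
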
diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs
index f0137fa..04320d1 100644
--- a/rust/tvm/src/ir/relay/attrs/nn.rs
+++ b/rust/tvm/src/ir/relay/attrs/nn.rs
@@ -56,6 +56,18 @@ pub struct BiasAddAttrsNode {
 
 #[repr(C)]
 #[derive(Object, Debug)]
+#[ref_name = "MatmulAttrs"]
+#[type_key = "relay.attrs.MatmulAttrs"]
+pub struct MatmulAttrsNode {
+    pub base: BaseAttrsNode,
+    pub units: IndexExpr,
+    pub out_dtype: DataType,
+    pub transpose_a: bool,
+    pub transpose_b: bool,
+}
+
+#[repr(C)]
+#[derive(Object, Debug)]
 #[ref_name = "DenseAttrs"]
 #[type_key = "relay.attrs.DenseAttrs"]
 pub struct DenseAttrsNode {
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 81de4bc..6f4db5a 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -44,6 +44,9 @@ Expr MakeClip(Expr a, double a_min, double a_max);
 
 Expr MakeConcatenate(Expr data, int axis);
 
+Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtype, bool transpose_a,
+                bool transpose_b);
+
 Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);
 
 Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype);
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 489be15..4eaa12b 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -162,7 +162,39 @@ Useful for
     .set_support_level(3)
     .add_type_rel("FIFOBuffer", FIFOBufferRel);
 
-// relay.nn.dense
+// ------------------- relay.nn.matmul
+TVM_REGISTER_NODE_TYPE(MatmulAttrs);
+
+Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtype, bool transpose_a,
+                bool transpose_b) {
+  auto attrs = make_object<MatmulAttrs>();
+  attrs->units = units;
+  attrs->out_dtype = out_dtype;
+  attrs->transpose_a = transpose_a;
+  attrs->transpose_b = transpose_b;
+  static const Op& matmul_op = Op::Get("nn.matmul");
+  return Call(matmul_op, {tensor_a, tensor_b}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.matmul").set_body_typed(MakeMatmul);
+
+RELAY_REGISTER_OP("nn.matmul")
+    .describe(R"code(Applies a linear transformation: :math:`C = A * B`. A & B can be transposed.
+
+- **tensor_a**: `(x1, x2, ..., xn, input_dim)` or `(x1, x2, ..., input_dim, xn)`
+- **tensor_b**: `(input_dim, units)` or `(units, input_dim)`
+- **out**: `(x1, x2, ..., xn, units)`.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<MatmulAttrs>()
+    .set_num_inputs(2)
+    .add_argument("tensor_a", "nD Tensor", "The first input Tensor.")
+    .add_argument("tensor_b", "2D Tensor", "The second input Tensor.")
+    .set_support_level(1)
+    .add_type_rel("Matmul", MatmulRel<MatmulAttrs>);
+// ------------------- relay.nn.matmul
+
+// ------------------- relay.nn.dense
 TVM_REGISTER_NODE_TYPE(DenseAttrs);
 
 // Positional relay function to create dense operator used by frontend FFI.
@@ -189,9 +221,10 @@ RELAY_REGISTER_OP("nn.dense")
     .add_argument("data", "nD Tensor", "Input data.")
     .add_argument("weight", "2D Tensor", "Weight matrix.")
     .set_support_level(1)
-    .add_type_rel("Dense", DenseRel<DenseAttrs>);
+    .add_type_rel("Dense", MatmulRel<DenseAttrs>);
+// ------------------- relay.nn.dense
 
-// relay.nn.contrib_dense_pack
+// ------------------- relay.nn.contrib_dense_pack
 // Positional relay function to create dense_pack operator used by frontend FFI.
 Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
   auto attrs = make_object<DenseAttrs>();
@@ -217,6 +250,7 @@ RELAY_REGISTER_OP("nn.contrib_dense_pack")
     .add_argument("weight", "3D Tensor", "Packed weight matrix.")
     .set_support_level(10)
     .add_type_rel("DensePack", DensePackRel<DenseAttrs>);
+// ------------------- relay.nn.contrib_dense_pack
 
 // relay.leaky_relu
 TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 1ac800f..29f200c 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -36,31 +36,44 @@ namespace tvm {
 namespace relay {
 
 template <typename AttrType>
-bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-              const TypeReporter& reporter) {
+bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (data == nullptr) return false;
+  const auto* tensor_a = types[0].as<TensorTypeNode>();
+  const auto* tensor_b = types[1].as<TensorTypeNode>();
+  if (tensor_a == nullptr) return false;
+  ICHECK(static_cast<int>(tensor_a->shape.size()) != 0);
 
   const AttrType* param = attrs.as<AttrType>();
   ICHECK(param != nullptr);
+  // Default set to dense layout
+  bool transpose_a = false;
+  bool transpose_b = true;
+  const auto& mattrs = attrs.as<MatmulAttrs>();
+  if (mattrs != nullptr) {
+    transpose_a = mattrs->transpose_a;
+    transpose_b = mattrs->transpose_b;
+  }
 
-  ICHECK(static_cast<int>(data->shape.size()) != 0);
-
-  Array<tvm::PrimExpr> dshape = data->shape;
+  const Array<tvm::PrimExpr>& dshape = tensor_a->shape;
   Array<tvm::PrimExpr> oshape = dshape;
+  tvm::PrimExpr reduce = dshape[dshape.size() - 1];
+  if (transpose_a) {
+    reduce = dshape[dshape.size() - 2];
+    oshape.Set((oshape.size() - 2), dshape[oshape.size() - 1]);
+  }
   if (param->units.defined()) {
-    // validate the weight shape is proper if defined
-    // Assign weight type
-    Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
-    // It is possible for weight to be nullptr in which case we will use
-    // data dtype as the weight dtype. However if weight dtype is explicitly
+    // validate the tensor_b shape is proper if defined
+    // Assign tensor_b type
+    const Array<IndexExpr>& wshape = transpose_b ? Array<IndexExpr>({param->units, reduce})
+                                                 : Array<IndexExpr>({reduce, param->units});
+    // It is possible for tensor_b to be nullptr in which case we will use
+    // data dtype as the tensor_b dtype. However if tensor_b dtype is explicitly
     // present we will use that.
-    auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
+    auto tensor_b_dtype = (tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
     if (param->auto_scheduler_rewritten_layout.size() == 0) {
       // Normal case: assign result to reporter
-      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+      reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
     } else {
       // If the layout is rewritten by auto-scheduler,
       // we just forcly apply the layout provided by auto-scheduler and
@@ -69,31 +82,32 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     }
     oshape.Set((oshape.size() - 1), param->units);
   } else {
-    if (weight == nullptr) return false;
-    Array<tvm::PrimExpr> wshape = weight->shape;
-    // When weight's layout has been rewritten, figure it out based on the
+    if (tensor_b == nullptr) return false;
+    const Array<tvm::PrimExpr>& wshape = tensor_b->shape;
+    // When tensor_b's layout has been rewritten, figure it out based on the
     // total number of elements and input dimensions.
     if (param->auto_scheduler_rewritten_layout.size() != 0) {
-      PrimExpr weight_elements = 1;
+      PrimExpr tensor_b_elements = 1;
       for (size_t i = 0; i < wshape.size(); i++) {
-        weight_elements = weight_elements * wshape[i];
+        tensor_b_elements = tensor_b_elements * wshape[i];
       }
-      oshape.Set(oshape.size() - 1, weight_elements / dshape[dshape.size() - 1]);
-      // Otherwise just pull it out of the weight shape directly.
+      oshape.Set(oshape.size() - 1, tensor_b_elements / dshape[dshape.size() - 1]);
+      // Otherwise just pull it out of the tensor_b shape directly.
     } else {
-      ICHECK(static_cast<int>(weight->shape.size()) == 2);
-      if (!data->shape.back().as<tir::AnyNode>()) {
-        ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
-            << "DenseRel: input dimension doesn't match,"
-            << " data shape=" << data->shape << ", weight shape=" << weight->shape;
+      ICHECK(static_cast<int>(tensor_b->shape.size()) == 2);
+      if (!tensor_a->shape.back().as<tir::AnyNode>()) {
+        ICHECK((transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[1])) ||
+               (!transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[0])))
+            << "MatmulRel: input dimension doesn't match,"
+            << " tensor_a shape=" << tensor_a->shape << ", tensor_b shape=" << tensor_b->shape;
       }
-      oshape.Set((oshape.size() - 1), wshape[0]);
+      oshape.Set((oshape.size() - 1), transpose_b ? wshape[0] : wshape[1]);
     }
   }
 
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
+    out_dtype = tensor_a->dtype;
   }
   // assign output type
   reporter->Assign(types[2], TensorType(oshape, out_dtype));
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index 6284524..592fa77 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -70,7 +70,7 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // Dense infer type function.
   Array<Type> tensor_types = {types[0], types[1], types[6]};
-  return DenseRel<DenseAttrs>(tensor_types, 3, attrs, reporter);
+  return MatmulRel<DenseAttrs>(tensor_types, 3, attrs, reporter);
 }
 
 // Positional relay function to create quantized dense operator used by frontend FFI.
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
index edc4119..da0bd35 100644
--- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -87,6 +87,8 @@ class FuncMutator : public ExprMutator {
         updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
       } else if (auto pattr = call->attrs.as<Conv3DAttrs>()) {
         updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+      } else if (auto pattr = call->attrs.as<MatmulAttrs>()) {
+        updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
       } else if (auto pattr = call->attrs.as<DenseAttrs>()) {
         updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
       } else if (auto pattr = call->attrs.as<BatchMatmulAttrs>()) {
@@ -103,9 +105,9 @@ class FuncMutator : public ExprMutator {
   std::deque<std::string> ori_layouts_queue_;
   std::deque<std::string> new_layouts_queue_;
 
-  std::vector<std::string> target_ops_{"nn.conv2d", "nn.conv3d",
-                                       "nn.contrib_conv2d_winograd_without_weight_transform",
-                                       "nn.dense", "nn.batch_matmul"};
+  std::vector<std::string> target_ops_{
+      "nn.conv2d", "nn.conv3d", "nn.contrib_conv2d_winograd_without_weight_transform",
+      "nn.matmul", "nn.dense",  "nn.batch_matmul"};
 };
 
 Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
@@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
         return attrs.as<Conv2DWinogradAttrs>()->auto_scheduler_rewritten_layout;
       } else if (attrs->IsInstance<Conv3DAttrs>()) {
         return attrs.as<Conv3DAttrs>()->auto_scheduler_rewritten_layout;
+      } else if (attrs->IsInstance<MatmulAttrs>()) {
+        return attrs.as<MatmulAttrs>()->auto_scheduler_rewritten_layout;
       } else if (attrs->IsInstance<DenseAttrs>()) {
         return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout;
       } else if (attrs->IsInstance<BatchMatmulAttrs>()) {
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 136dcab..583014f 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -117,6 +117,7 @@ def run_tvm_graph(
     disabled_pass=None,
     ignore_in_shape=False,
     serialize=False,
+    use_dense_op=True,
 ):
     """Generic function to compile on relay and execute on tvm"""
     input_data = convert_to_list(input_data)
@@ -131,7 +132,11 @@ def run_tvm_graph(
             e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)
         }
     mod, params = relay.frontend.from_tensorflow(
-        graph_def, layout=layout, shape=shape_dict, outputs=out_names
+        graph_def,
+        layout=layout,
+        shape=shape_dict,
+        outputs=out_names,
+        use_dense_op=use_dense_op,
     )
     dev = tvm.device(target, 0)
     if mode == "debug":
@@ -213,6 +218,7 @@ def compare_tf_with_tvm(
     add_shapes_to_graph_def=True,
     targets=None,
     ignore_in_shape=False,
+    use_dense_op=True,
 ):
     """Generic function to generate and compare tensorflow and TVM output"""
 
@@ -260,6 +266,7 @@ def compare_tf_with_tvm(
                 mode=mode,
                 cuda_layout=cuda_layout,
                 ignore_in_shape=ignore_in_shape,
+                use_dense_op=use_dense_op,
             )
             # since the names from tensorflow and relay runs are not exactly same,
             # first len(tf_output) will be compared
@@ -1795,7 +1802,8 @@ def _test_matmul(i, j, k, dtype, outer=None):
 
                 A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
                 B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
-                compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
+                compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=True)
+                compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=False)
 
 
 def test_forward_matmul():
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index 686fd98..c8a9468 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -199,6 +199,22 @@ def test_dense_grad():
     verify_dense_grad((5, 4), (3, 4))
 
 
+def verify_matmul_grad(a_shape, b_shape, transpose_a, transpose_b):
+    tensor_a = relay.var("tensor_a", relay.TensorType(a_shape, "float32"))
+    tensor_b = relay.var("tensor_b", relay.TensorType(b_shape, "float32"))
+    fwd_func = relay.Function(
+        [tensor_a, tensor_b],
+        relay.nn.matmul(tensor_a, tensor_b, transpose_a=transpose_a, transpose_b=transpose_b),
+    )
+    check_grad(fwd_func)
+
+
+def test_matmul_grad():
+    verify_matmul_grad((1, 8), (8, 16), False, False)
+    verify_matmul_grad((4, 1), (4, 3), True, False)
+    verify_matmul_grad((4, 5), (3, 4), True, True)
+
+
 def verify_batch_flatten_grad(d_shape):
     data = relay.var("data", relay.TensorType(d_shape, "float32"))
     fwd_func = relay.Function([data], relay.nn.batch_flatten(data))
@@ -216,4 +232,5 @@ if __name__ == "__main__":
     test_global_avg_pool2d_grad()
     test_conv2d_grad()
     test_dense_grad()
+    test_matmul_grad()
     test_batch_flatten_grad()
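
As a quick sanity check of the gradient rule added in _tensor_grad.py and exercised by the tests above, the identities dA = dC @ B^T and dB = A^T @ dC for the non-transposed case can be verified numerically with plain numpy (shapes below are arbitrary):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.standard_normal((3, 4))
    B = rng.standard_normal((4, 5))
    dC = rng.standard_normal((3, 5))  # upstream gradient, same shape as C = A @ B

    # matmul_grad for (transpose_a, transpose_b) == (False, False):
    # [matmul(grad, B, transpose_b=True), matmul(A, grad, transpose_a=True)]
    dA = dC @ B.T
    dB = A.T @ dC

    # finite-difference check of one entry of dA
    eps = 1e-6
    A_pert = A.copy()
    A_pert[0, 0] += eps
    numeric = (((A_pert @ B) - (A @ B)) * dC).sum() / eps
    assert np.isclose(numeric, dA[0, 0], rtol=1e-4)
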
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 89475ac..cbc3e7f 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -410,6 +410,66 @@ def test_batch_norm():
 
 
 @pytest.mark.xfail
+def test_matmul_type_check():
+    dtype = "float16"
+    n, c, h, w = 2, 2, 2, 2
+    x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+    # it should fail since it does not match with m(2)
+    mismatch_w = 3
+    w = relay.var("w", relay.TensorType((mismatch_w, 2), dtype))
+    y = relay.nn.matmul(x, w)
+    yy = run_infer_type(y)
+
+
+@tvm.testing.uses_gpu
+def test_matmul():
+    for dtype in ["float16", "float32"]:
+        # Matmul accuracy for float16 is poor
+        if dtype == "float16":
+            continue
+        n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+        x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+        w = relay.var("w", relay.TensorType((2, w), dtype))
+        y = relay.nn.matmul(x, w, units=2, transpose_b=True)
+        assert "units=2" in y.astext()
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
+
+        n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
+        x = relay.var("x", relay.TensorType((n, c, w, h), dtype))
+        wh, ww = te.size_var("wh"), te.size_var("ww")
+        w = relay.var("w", relay.TensorType((wh, ww), dtype))
+        y = relay.nn.matmul(x, w, transpose_a=True)
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype)
+
+        n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
+        x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+        w = relay.var("w", relay.IncompleteType())
+        y = relay.nn.matmul(x, w, units=2)
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
+
+        x = relay.var("x", shape=(5, 10), dtype=dtype)
+        w = relay.var("w", shape=(5, 2), dtype=dtype)
+        z = relay.nn.matmul(x, w, transpose_a=True)
+
+        # Check result.
+        func = relay.Function([x, w], z)
+        x_data = np.random.rand(5, 10).astype(dtype)
+        w_data = np.random.rand(5, 2).astype(dtype)
+        ref_res = np.dot(x_data.transpose(), w_data)
+
+        for target, dev in tvm.testing.enabled_targets():
+            intrp1 = relay.create_executor("graph", device=dev, target=target)
+            intrp2 = relay.create_executor("debug", device=dev, target=target)
+            op_res1 = intrp1.evaluate(func)(x_data, w_data)
+            tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5)
+            op_res2 = intrp2.evaluate(func)(x_data, w_data)
+            tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5)
+
+
+@pytest.mark.xfail
 def test_dense_type_check():
     dtype = "float16"
     n, c, h, w = 2, 2, 2, 2
@@ -426,7 +486,7 @@ def test_dense():
     for dtype in ["float16", "float32"]:
         # Dense accuracy for float16 is poor
         if dtype == "float16":
-            return
+            continue
         n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
         x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
         w = relay.var("w", relay.TensorType((2, w), dtype))
@@ -506,6 +566,7 @@ if __name__ == "__main__":
     test_log_softmax()
     test_dropout()
     test_batch_norm()
+    test_matmul()
     test_dense()
     test_bitserial_dense()
     test_dense_dtype()
diff --git a/tests/python/topi/python/test_topi_matmul.py b/tests/python/topi/python/test_topi_matmul.py
index e5a21a3..de2d4d3 100644
--- a/tests/python/topi/python/test_topi_matmul.py
+++ b/tests/python/topi/python/test_topi_matmul.py
@@ -41,6 +41,31 @@ def with_tvm(lam, *args):
     return out_nd.numpy()
 
 
+def verify_nn_matmul(sa, sb, transp_a, transp_b):
+    a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
+    b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
+    c1 = np.matmul(np.transpose(a) if transp_a else a, np.transpose(b) if transp_b else b)
+    c2 = with_tvm(
+        lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b),
+        a,
+        b,
+    )
+    tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
+
+
+def test_nn_matmul():
+    verify_nn_matmul((1, 1), (1, 1), False, False)
+    verify_nn_matmul((1, 1), (1, 1), True, True)
+    verify_nn_matmul((2, 2), (2, 2), False, False)
+    verify_nn_matmul((2, 2), (2, 2), True, True)
+    verify_nn_matmul((2, 3), (3, 5), False, False)
+    verify_nn_matmul((5, 3), (3, 2), False, False)
+    verify_nn_matmul((3, 5), (3, 2), True, False)
+    verify_nn_matmul((3, 5), (2, 3), True, True)
+    verify_nn_matmul((3, 5), (3, 2), True, False)
+    verify_nn_matmul((5, 3), (2, 3), False, True)
+
+
 def verify_matmul(sa, sb, transp_a, transp_b):
     a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
     b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
@@ -79,5 +104,6 @@ def test_tensordot():
 
 
 if __name__ == "__main__":
+    test_nn_matmul()
     test_matmul()
     test_tensordot()
