This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6d1ced0 [Matmul] Add matmul op (#8234)
6d1ced0 is described below
commit 6d1ced080390d221d1d64f37038cc869e0127f45
Author: Chenfan <[email protected]>
AuthorDate: Wed Jun 30 22:29:43 2021 +0800
[Matmul] Add matmul op (#8234)
* Add Matmul Op
* Recover DenseAttrs
* Add grad for matmul & some update
* Update matmul cuda default schedule
* Add blas support for matmul
* Lint fix add update doc strings
---
include/tvm/relay/attrs/nn.h | 26 ++++
python/tvm/relay/frontend/tensorflow.py | 19 ++-
python/tvm/relay/frontend/tensorflow_ops.py | 20 ++-
python/tvm/relay/op/_tensor_grad.py | 29 +++++
python/tvm/relay/op/nn/_nn.py | 63 ++++++++--
python/tvm/relay/op/nn/nn.py | 44 +++++++
python/tvm/relay/op/op_attrs.py | 5 +
python/tvm/relay/op/strategy/cuda.py | 32 +++++
python/tvm/relay/op/strategy/generic.py | 36 ++++++
python/tvm/relay/op/strategy/x86.py | 72 +++++++++++
python/tvm/topi/cuda/dense.py | 50 ++++++--
python/tvm/topi/generic/nn.py | 17 +++
python/tvm/topi/gpu/dense.py | 30 +++++
python/tvm/topi/nn/dense.py | 138 ++++++++++++++++++---
python/tvm/topi/x86/dense.py | 119 ++++++++++++------
rust/tvm/src/ir/relay/attrs/nn.rs | 12 ++
src/relay/op/make_op.h | 3 +
src/relay/op/nn/nn.cc | 40 +++++-
src/relay/op/nn/nn.h | 72 ++++++-----
src/relay/qnn/op/dense.cc | 2 +-
.../transforms/auto_scheduler_layout_rewrite.cc | 10 +-
tests/python/frontend/tensorflow/test_forward.py | 12 +-
tests/python/relay/test_op_grad_level2.py | 17 +++
tests/python/relay/test_op_level1.py | 63 +++++++++-
tests/python/topi/python/test_topi_matmul.py | 26 ++++
25 files changed, 842 insertions(+), 115 deletions(-)
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index dc20267..3c75745 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -961,6 +961,32 @@ struct AvgPool3DAttrs : public
tvm::AttrsNode<AvgPool3DAttrs> {
}
};
+/*! \brief Attributes for matmul operator */
+struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
+ IndexExpr units;
+ DataType out_dtype;
+ bool transpose_a;
+ bool transpose_b;
+ tvm::String auto_scheduler_rewritten_layout; // The layout after
auto-scheduler's layout rewrite
+
+ TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") {
+ TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense
transformation.");
+
+ // use 0 bits to indicate none.
+ TVM_ATTR_FIELD(out_dtype)
+ .set_default(NullValue<DataType>())
+ .describe("Output data type, set to explicit type under mixed
precision setting");
+
+ TVM_ATTR_FIELD(transpose_a)
+ .set_default(false)
+ .describe("Whether the first input tensor is in transposed format.");
+
+ TVM_ATTR_FIELD(transpose_b)
+ .set_default(false)
+ .describe("Whether the second input tensor is in transposed format.");
+ }
+};
+
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
diff --git a/python/tvm/relay/frontend/tensorflow.py
b/python/tvm/relay/frontend/tensorflow.py
index 0bdec95..e297398 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -44,6 +44,16 @@ from .tensorflow_ops import _get_more_static_shape
__all__ = ["from_tensorflow"]
+# The default configurations of Relay TensorFlow frontend.
+TF_DEFAULT_CONFIGS = {
+ # By default, TVM converts `tf.matmul` to `transpose(weight) + nn.dense`,
which introduces
+ # unnecessary overhead in weight transpose. Change this flag to False to
directly convert to
+ # `nn.matmul` to get rid of the overhead.
+ # However, please note that `nn.matmul` is experimental, so it may have
some performance
+ # issues.
+ "use_dense": True,
+}
+
# compatible operators that do NOT require any conversion.
_identity_list = []
@@ -1204,7 +1214,7 @@ class SubGraphProto(GraphProto):
return func, self._params
-def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
+def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None,
use_dense_op=True):
"""Load tensorflow graph which is a python tensorflow graph object into
relay.
The companion parameters will be handled automatically.
@@ -1222,6 +1232,11 @@ def from_tensorflow(graph, layout="NHWC", shape=None,
outputs=None):
outputs : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.
+ use_dense_op : bool (Optional) = True
+ True to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
+ The `nn.dense` op requires the data tensor to be non-transposed and
weight tensor to be
+ transposed, may insert extra `transpose` to the original graph.
+
Returns
-------
mod : tvm.IRModule
@@ -1230,6 +1245,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None,
outputs=None):
params : dict of str to tvm.nd.NDArray
Dict of converted parameters stored in tvm.nd.NDArray format
"""
+ global TF_DEFAULT_CONFIGS
+ TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op
g = GraphProto()
mod, params = g.from_tensorflow(graph, layout, shape, outputs)
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py
b/python/tvm/relay/frontend/tensorflow_ops.py
index be15f83..004174f 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1113,13 +1113,23 @@ def _no_op():
def _matmul():
def _impl(inputs, attr, params, mod):
+ from .tensorflow import TF_DEFAULT_CONFIGS
+
channels = _infer_channels(inputs[1], not attr["transpose_b"])
- if attr["transpose_a"]:
- inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
- if not attr["transpose_b"]:
- inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
+ if TF_DEFAULT_CONFIGS["use_dense"]:
+ if attr["transpose_a"]:
+ inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
+ if not attr["transpose_b"]:
+ inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
+ return AttrCvt(
+ op_name="dense",
+ extras={"units": channels},
+ ignores=["transpose_a", "transpose_b", "T"],
+ )(inputs, attr)
return AttrCvt(
- op_name="dense", extras={"units": channels},
ignores=["transpose_a", "transpose_b", "T"]
+ op_name="matmul",
+ extras={"units": channels},
+ ignores=["T"],
)(inputs, attr)
return _impl
diff --git a/python/tvm/relay/op/_tensor_grad.py
b/python/tvm/relay/op/_tensor_grad.py
index 09b1435..fa2772c 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -554,6 +554,35 @@ def dense_grad(orig, grad):
]
+@register_gradient("nn.matmul")
+def matmul_grad(orig, grad):
+ """Returns [grad' @ tensor_b, tensor_a @ grad']"""
+ tensor_a, tensor_b = orig.args
+ if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
+ return [
+ collapse_sum_like(
+ _nn.matmul(tensor_b, grad, transpose_a=True,
transpose_b=True), tensor_a
+ ),
+ collapse_sum_like(
+ _nn.matmul(grad, tensor_a, transpose_a=True,
transpose_b=True), tensor_b
+ ),
+ ]
+ if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
+ return [
+ collapse_sum_like(_nn.matmul(tensor_b, grad, transpose_b=True),
tensor_a),
+ collapse_sum_like(_nn.matmul(tensor_a, grad), tensor_b),
+ ]
+ if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
+ # Keep using Dense op here for not involving extra ops
+ # TODO(jcf94): Merge all to nn.matmul when it is finally ready
+ return dense_grad(orig, grad)
+ # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
+ return [
+ collapse_sum_like(_nn.matmul(grad, tensor_b, transpose_b=True),
tensor_a),
+ collapse_sum_like(_nn.matmul(tensor_a, grad, transpose_a=True),
tensor_b),
+ ]
+
+
@register_gradient("nn.batch_matmul")
def batch_matmul_grad(orig, grad):
"""gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 04d38ce..056cb56 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -52,6 +52,32 @@ reg.register_schedule("nn.log_softmax",
strategy.schedule_log_softmax)
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
[email protected]_legalize("nn.matmul")
+def legalize_matmul(attrs, inputs, types):
+ """Legalize matmul op.
+
+ Parameters
+ ----------
+ attrs : tvm.ir.Attrs
+ Attributes of current matmul
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized
+ types : list of types
+ List of input and output types
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The legalized expr
+ """
+ return topi.nn.matmul_legalize(attrs, inputs, types)
+
+
+# matmul
+reg.register_strategy("nn.matmul", strategy.matmul_strategy)
+reg.register_pattern("nn.matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
@reg.register_legalize("nn.dense")
def legalize_dense(attrs, inputs, types):
"""Legalize dense op.
@@ -1160,21 +1186,44 @@ def batch_flatten_shape_func(attrs, inputs, _):
@script
-def _dense_shape_func(data_shape, weight_shape):
- out = output_tensor((data_shape.shape[0],), "int64")
+def _matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a,
transpose_b):
+ out = output_tensor((tensor_a_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
- out[i] = data_shape[i]
- out[out.shape[0] - 1] = weight_shape[0]
+ out[i] = tensor_a_shape[i]
+ if transpose_a:
+ out[out.shape[0] - 2] = out[out.shape[0] - 1]
+ out[out.shape[0] - 1] = tensor_b_shape[0] if transpose_b else
tensor_b_shape[1]
return out
[email protected]_shape_func("nn.matmul", False)
+def matmul_shape_func(attrs, inputs, _):
+ """Shape function for matmul op."""
+ ret = [
+ _matmul_shape_func(
+ inputs[0],
+ inputs[1],
+ expr.IntImm("bool", attrs.transpose_a),
+ expr.IntImm("bool", attrs.transpose_b),
+ )
+ ]
+ return ret
+
+
@reg.register_shape_func("nn.dense", False)
def dense_shape_func(attrs, inputs, _):
+ """Shape function for dense op. This is an alias of matmul_nt operator for
data tensor in
+ non-transposed format and weight tensor in transposed format.
"""
- Shape function for dense op.
- """
- ret = [_dense_shape_func(inputs[0], inputs[1])]
+ ret = [
+ _matmul_shape_func(
+ inputs[0],
+ inputs[1],
+ expr.IntImm("bool", False),
+ expr.IntImm("bool", True),
+ )
+ ]
return ret
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index bef899e..4c94102 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1471,6 +1471,50 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)
+def matmul(tensor_a, tensor_b, units=None, out_dtype="", transpose_a=False,
transpose_b=False):
+ """Matmul operator.
+ Applies a linear transformation. The A & B can be transposed.
+
+ .. math::
+
+ `C = A * B`
+
+ Parameters
+ ----------
+ tensor_a : tvm.relay.Expr
+ The first input of the operator,
+ of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ...,
units_in, d_n)`.
+
+ tensor_b : tvm.relay.Expr
+ The second input expressions, 2-D matrix,
+ of shape `(units_in, units)` or `(units, units_in)`.
+
+ units : Optional[int]
+ Number of hidden units of the matmul transformation.
+
+ out_dtype : Optional[str]
+ Specifies the output data type for mixed precision matmul,
+ of shape `(d_1, d_2, ..., d_n, units)`.
+
+ transpose_a : Optional[bool] = False
+ Whether the data tensor is in transposed format.
+
+ transpose_b : Optional[bool] = False
+ Whether the weight tensor is in transposed format.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+ # Since currently `nn.dense` has better topi schedule support, will prefer
to use `dense`
+ # rather than `matmul` for better compatibility
+ if not transpose_a and transpose_b:
+ # TODO(jcf94): Remove this when `nn.matmul` is finally ready
+ return dense(tensor_a, tensor_b, units, out_dtype)
+ return _make.matmul(tensor_a, tensor_b, units, out_dtype, transpose_a,
transpose_b)
+
+
def dense(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 74c4e2f..780badc 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -64,6 +64,11 @@ class BiasAddAttrs(Attrs):
"""Atttribute of nn.bias_add"""
+@tvm._ffi.register_object("relay.attrs.MatmulAttrs")
+class MatmulAttrs(Attrs):
+ """Attributes for nn.matmul"""
+
+
@tvm._ffi.register_object("relay.attrs.DenseAttrs")
class DenseAttrs(Attrs):
"""Attributes for nn.dense"""
diff --git a/python/tvm/relay/op/strategy/cuda.py
b/python/tvm/relay/op/strategy/cuda.py
index 683f3ec..dd265e4 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -698,6 +698,38 @@ def conv1d_transpose_strategy_cuda(attrs, inputs,
out_type, target):
return strategy
+@matmul_strategy.register(["cuda", "gpu"])
+def matmul_strategy_cuda(attrs, inputs, out_type, target):
+ """Matmul cuda strategy."""
+ strategy = _op.OpStrategy()
+
+ if is_auto_scheduler_enabled():
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.nn.matmul),
+ naive_schedule,
+ name="matmul.cuda",
+ )
+ else:
+ logger.warning(
+ "Matmul is not optimized for cuda. Recommend to use cublas for
better performance."
+ )
+ # Temporarily use this as a basic schedule
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.gpu.matmul_default),
+ wrap_topi_schedule(topi.gpu.schedule_matmul_default),
+ name="matmul_default.gpu",
+ )
+
+ if target.kind.name == "cuda" and "cublas" in target.libs:
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.cuda.matmul_cublas),
+ wrap_topi_schedule(topi.cuda.schedule_matmul_cublas),
+ name="matmul_cublas.cuda",
+ plevel=25,
+ )
+ return strategy
+
+
@dense_strategy.register(["cuda", "gpu"])
def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py
b/python/tvm/relay/op/strategy/generic.py
index 35e5177..5cb3f65 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -712,6 +712,42 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
return strategy
+# matmul
+def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
+ """wrap matmul topi compute"""
+
+ def _compute_matmul(attrs, inputs, out_type):
+ """Compute definition of matmul"""
+ out_dtype = attrs.out_dtype
+ out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
+ args = [
+ inputs[0],
+ inputs[1],
+ None,
+ out_dtype,
+ attrs.transpose_a,
+ attrs.transpose_b,
+ ]
+ if need_auto_scheduler_layout:
+ args.append(get_auto_scheduler_rewritten_layout(attrs))
+ return [topi_compute(*args)]
+
+ return _compute_matmul
+
+
+@override_native_generic_func("matmul_strategy")
+def matmul_strategy(attrs, inputs, out_type, target):
+ """matmul generic strategy"""
+ logger.warning("matmul is not optimized for this platform.")
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.nn.matmul),
+ wrap_topi_schedule(topi.generic.schedule_matmul),
+ name="matmul.generic",
+ )
+ return strategy
+
+
# dense
def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
"""wrap dense topi compute"""
diff --git a/python/tvm/relay/op/strategy/x86.py
b/python/tvm/relay/op/strategy/x86.py
index c21ec4d..d09d90a 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -370,6 +370,78 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):
return strategy
+@matmul_strategy.register("cpu")
+def matmul_strategy_cpu(attrs, inputs, out_type, target):
+ """matmul x86 strategy"""
+ strategy = _op.OpStrategy()
+
+ same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
+ dtype = inputs[0].dtype
+ u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and
out_type.dtype == "int32"
+ if "cblas" in target.libs:
+ length_before = len(strategy.specializations) if
strategy.specializations else 0
+ with SpecializedCondition(same_type and dtype in ["float32",
"float64"]):
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.x86.matmul_cblas),
+ wrap_topi_schedule(topi.x86.schedule_matmul_cblas),
+ name="matmul_cblas.x86",
+ plevel=13,
+ )
+ length_after = len(strategy.specializations) if
strategy.specializations else 0
+ if length_before == length_after:
+ logger.warning(
+ "Currently cblas only support the data type to be float32 or
float64. Skip."
+ )
+ if "mkl" in target.libs:
+ length_before = len(strategy.specializations) if
strategy.specializations else 0
+ with SpecializedCondition(same_type and dtype in ["float32",
"float64"] or u8s8s32):
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.x86.matmul_mkl),
+ wrap_topi_schedule(topi.x86.schedule_matmul_mkl),
+ name="matmul_mkl.x86",
+ plevel=14,
+ )
+ length_after = len(strategy.specializations) if
strategy.specializations else 0
+ if length_before == length_after:
+ logger.warning(
+ "Currently mkl only support the data type to be float32,
float64 or input with "
+ "uint8 and int8 while output wiht int32. Skip."
+ )
+ if "mkldnn" in target.libs:
+ length_before = len(strategy.specializations) if
strategy.specializations else 0
+ with SpecializedCondition(same_type and dtype == "float32"):
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.x86.matmul_mkldnn),
+ wrap_topi_schedule(topi.x86.schedule_matmul_mkldnn),
+ name="matmul_mkldnn.x86",
+ plevel=15,
+ )
+ length_after = len(strategy.specializations) if
strategy.specializations else 0
+ if length_before == length_after:
+ logger.warning("Currently mkldnn only support the data type to be
float32. Skip.")
+
+ if is_auto_scheduler_enabled():
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.nn.matmul,
need_auto_scheduler_layout=True),
+ naive_schedule,
+ name="matmul.generic",
+ plevel=11,
+ )
+ else:
+ # If no cblas/mkl/mkldnn strategy was chosen
+ if not strategy.specializations:
+ logger.warning(
+ "Matmul is not optimized for x86. "
+ "Recommend to use cblas/mkl/mkldnn for better performance."
+ )
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.nn.matmul),
+ naive_schedule,
+ name="matmul.generic",
+ )
+ return strategy
+
+
@dense_strategy.register("cpu")
def dense_strategy_cpu(attrs, inputs, out_type, target):
"""dense x86 strategy"""
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 0f410ae..4035dce 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -28,18 +28,24 @@ from ..utils import traverse_inline, get_const_tuple
logger = logging.getLogger("topi")
[email protected]_topi_compute("dense_cublas.cuda")
-def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
- """Dense operator on CUDA with CUBLAS"""
- assert len(data.shape) == 2 and len(weight.shape) == 2, "only support
2-dim dense"
+def _matmul_cublas_common(
+ cfg,
+ tensor_a,
+ tensor_b,
+ bias=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=False,
+):
+ assert len(tensor_a.shape) == 2 and len(tensor_b.shape) == 2, "only
support 2-dim matmul"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
- out_dtype = data.dtype
- assert out_dtype == data.dtype, "Mixed precision not supported."
- batch, in_dim = get_const_tuple(data.shape)
- out_dim, _ = get_const_tuple(weight.shape)
- matmul = cublas.matmul(data, weight, False, True)
+ out_dtype = tensor_a.dtype
+ assert out_dtype == tensor_a.dtype, "Mixed precision not supported."
+ batch, in_dim = get_const_tuple(tensor_a.shape)
+ out_dim, _ = get_const_tuple(tensor_b.shape)
+ matmul = cublas.matmul(tensor_a, tensor_b, transpose_a, transpose_b)
if all(isinstance(d, int) for d in [batch, in_dim, out_dim]):
cfg.add_flop(batch * in_dim * out_dim * 2)
if bias is not None:
@@ -49,6 +55,32 @@ def dense_cublas(cfg, data, weight, bias=None,
out_dtype=None):
return matmul
[email protected]_topi_compute("matmul_cublas.cuda")
+def matmul_cublas(
+ cfg,
+ tensor_a,
+ tensor_b,
+ bias=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=False,
+):
+ """Matmul operator on CUDA with CUBLAS"""
+ return _matmul_cublas_common(cfg, tensor_a, tensor_b, bias, out_dtype,
transpose_a, transpose_b)
+
+
[email protected]_topi_schedule("matmul_cublas.cuda")
+def schedule_matmul_cublas(_, outs):
+ """Schedule matmul operator using CUBLAS"""
+ return generic.schedule_extern(outs)
+
+
[email protected]_topi_compute("dense_cublas.cuda")
+def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
+ """Dense operator on CUDA with CUBLAS. This is an alias of matmul_nt
operator."""
+ return _matmul_cublas_common(cfg, data, weight, bias, out_dtype, False,
True)
+
+
@autotvm.register_topi_schedule("dense_cublas.cuda")
def schedule_dense_cublas(_, outs):
"""Schedule dense operator using CUBLAS"""
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 04d6490..1b32141 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -580,6 +580,23 @@ def schedule_fast_softmax(outs):
return _default_schedule(outs, False)
+def schedule_matmul(outs):
+ """Schedule for matmul
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of matmul
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ sch: Schedule
+ The computation schedule for the op.
+ """
+ return _default_schedule(outs, False)
+
+
def schedule_dense(outs):
"""Schedule for dense
diff --git a/python/tvm/topi/gpu/dense.py b/python/tvm/topi/gpu/dense.py
index 806aa9f..b9009d3 100644
--- a/python/tvm/topi/gpu/dense.py
+++ b/python/tvm/topi/gpu/dense.py
@@ -49,6 +49,36 @@ def schedule_dense_small_batch(cfg, outs):
return s
[email protected]_topi_compute("matmul_default.gpu")
+def matmul_default(
+ cfg,
+ tensor_a,
+ tensor_b,
+ bias=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=False,
+):
+ """Matmul operator on GPU"""
+ return nn.matmul(tensor_a, tensor_b, bias, out_dtype, transpose_a,
transpose_b)
+
+
[email protected]_topi_schedule("matmul_default.gpu")
+def schedule_matmul_default(cfg, outs):
+ """Schedule matmul on GPU"""
+ outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+ s = te.create_schedule([x.op for x in outs])
+
+ def _callback(op):
+ if op.tag == "matmul":
+ # Temporarily use this as a basic schedule for matmul
+ # TODO(jcf94): Add a more general schedule for matmul
+ _schedule_dense_small_batch(cfg, s, op.output(0))
+
+ traverse_inline(s, outs[0].op, _callback)
+ return s
+
+
def _schedule_dense_small_batch(cfg, s, C):
A, weights = C.op.input_tensors
_, in_dim_weights = get_const_tuple(weights.shape)
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index e8ec476..58c458a 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -21,15 +21,23 @@ from tvm import te, auto_scheduler
from .. import tag
-def dense(data, weight, bias=None, out_dtype=None,
auto_scheduler_rewritten_layout=""):
- """The default implementation of dense in topi.
+def matmul(
+ tensor_a,
+ tensor_b,
+ bias=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=False,
+ auto_scheduler_rewritten_layout="",
+):
+ """The default implementation of matmul in topi.
Parameters
----------
- data : tvm.te.Tensor
+ tensor_a : tvm.te.Tensor
2-D with shape [batch, in_dim]
- weight : tvm.te.Tensor
+ tensor_b : tvm.te.Tensor
2-D with shape [out_dim, in_dim]
bias : Optional[tvm.te.Tensor]
@@ -38,7 +46,13 @@ def dense(data, weight, bias=None, out_dtype=None,
auto_scheduler_rewritten_layo
out_dtype : Optional[str]
The output type. This is used for mixed precision.
- auto_scheduler_rewritten_layout: str = ""
+ transpose_a : Optional[bool] = False
+ Whether the tensor_a is in transposed format.
+
+ transpose_b : Optional[bool] = False
+ Whether the tensor_b is in transposed format.
+
+ auto_scheduler_rewritten_layout: Optional[str] = ""
The layout after auto-scheduler's layout rewrite pass.
Returns
@@ -46,42 +60,128 @@ def dense(data, weight, bias=None, out_dtype=None,
auto_scheduler_rewritten_layo
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
- assert len(data.shape) == 2, "only support 2-dim dense"
+ # TODO(jcf94): Add multi-dim support for tensor_a
+ assert len(tensor_a.shape) == 2, "only support 2-dim matmul"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
- out_dtype = data.dtype
- batch, in_dim = data.shape
+ out_dtype = tensor_a.dtype
+ if transpose_a:
+ in_dim, batch = tensor_a.shape
+ else:
+ batch, in_dim = tensor_a.shape
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
auto_scheduler_rewritten_layout, ["j", "k"]
)
- auto_scheduler.remove_index_check(weight)
+ auto_scheduler.remove_index_check(tensor_b)
+ elif transpose_b:
+ out_dim, red_dim = tensor_b.shape
else:
- out_dim, red_dim = weight.shape
+ red_dim, out_dim = tensor_b.shape
assert in_dim == red_dim
k = te.reduce_axis((0, in_dim), name="k")
- matmul = te.compute(
+ if (transpose_a, transpose_b) == (True, True):
+ compute_lambda = lambda i, j: te.sum(
+ tensor_a[k, i].astype(out_dtype) * tensor_b[j,
k].astype(out_dtype), axis=k
+ )
+ compute_name = "T_matmul_TT"
+ compute_tag = "matmul"
+ elif (transpose_a, transpose_b) == (True, False):
+ compute_lambda = lambda i, j: te.sum(
+ tensor_a[k, i].astype(out_dtype) * tensor_b[k,
j].astype(out_dtype), axis=k
+ )
+ compute_name = "T_matmul_TN"
+ compute_tag = "matmul"
+ elif (transpose_a, transpose_b) == (False, True):
+ compute_lambda = lambda i, j: te.sum(
+ tensor_a[i, k].astype(out_dtype) * tensor_b[j,
k].astype(out_dtype), axis=k
+ )
+ compute_name = "T_matmul_NT"
+ # TODO(jcf94): Remove `dense` when `matmul` is finally ready
+ compute_tag = "dense"
+ else: # (transpose_a, transpose_b) == (False, False):
+ compute_lambda = lambda i, j: te.sum(
+ tensor_a[i, k].astype(out_dtype) * tensor_b[k,
j].astype(out_dtype), axis=k
+ )
+ compute_name = "T_matmul_NN"
+ compute_tag = "matmul"
+
+ mat = te.compute(
(batch, out_dim),
- lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j,
k].astype(out_dtype), axis=k),
- name="T_dense",
- tag="dense",
- attrs={"layout_free_placeholders": [weight]},
+ compute_lambda,
+ name=compute_name,
+ tag=compute_tag,
+ attrs={"layout_free_placeholders": [tensor_b]},
)
+
if bias is not None:
- matmul = te.compute(
+ mat = te.compute(
(batch, out_dim),
- lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
+ lambda i, j: mat[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST,
)
if auto_scheduler_rewritten_layout:
- matmul = auto_scheduler.rewrite_compute_body(matmul,
auto_scheduler_rewritten_layout)
+ mat = auto_scheduler.rewrite_compute_body(mat,
auto_scheduler_rewritten_layout)
+
+ return mat
+
+
[email protected]_func
+def matmul_legalize(attrs, inputs, types):
+ """Legalizes matmul op.
+
+ Parameters
+ ----------
+ attrs : tvm.ir.Attrs
+ Attributes of current matmul
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized
+ types : list of types
+ List of input and output types
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The legalized expr
+ """
+ # not to change by default
+ # pylint: disable=unused-argument
+ return None
+
+
+def dense(data, weight, bias=None, out_dtype=None,
auto_scheduler_rewritten_layout=""):
+ """The default implementation of dense in topi.
+ This is an alias of matmul_nt operator for data tensor in non-transposed
format and weight
+ tensor in transposed format.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ 2-D with shape [batch, in_dim]
- return matmul
+ weight : tvm.te.Tensor
+ 2-D with shape [out_dim, in_dim]
+
+ bias : Optional[tvm.te.Tensor]
+ 1-D with shape [out_dim]
+
+ out_dtype : Optional[str]
+ The output type. This is used for mixed precision.
+
+ auto_scheduler_rewritten_layout: str = ""
+ The layout after auto-scheduler's layout rewrite pass.
+
+ Returns
+ -------
+ output : tvm.te.Tensor
+ 2-D with shape [batch, out_dim]
+ """
+ return matmul(data, weight, bias, out_dtype, False, True,
auto_scheduler_rewritten_layout)
@tvm.target.generic_func
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 4fed4c1..189ac5b 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -28,7 +28,7 @@ from tvm.contrib import mkldnn
from .utils import get_fp32_len
from .injective import schedule_injective_from_existing
-from .. import generic, tag
+from .. import tag
from ..utils import traverse_inline, get_const_tuple
@@ -281,72 +281,121 @@ def schedule_dense_pack(cfg, outs):
return s
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
- """Compute dense using a BLAS library"""
- M, K = get_const_tuple(data.shape)
- N, _ = get_const_tuple(weight.shape)
+def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a,
transpose_b, lib):
+ """Compute matmul/dense using a BLAS library"""
+ M, K = get_const_tuple(tensor_a.shape)
+ N, _ = get_const_tuple(tensor_b.shape)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
- if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype ==
"int32":
+ if tensor_a.dtype == "uint8" and tensor_b.dtype == "int8" and out_dtype ==
"int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
- f"Dense with {lib.__name__} for {data.dtype} is not supported "
+ f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not
supported "
"(matmulu8s8s32 not imlemented)"
)
- C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
- elif data.dtype == "float32" or data.dtype == "float64":
- C = lib.matmul(data, weight, False, True)
+ C = lib.matmul_u8s8s32(tensor_a, tensor_b, transpose_a, transpose_b,
dtype=out_dtype)
+ elif tensor_a.dtype == "float32" or tensor_a.dtype == "float64":
+ C = lib.matmul(tensor_a, tensor_b, transpose_a, transpose_b)
else:
- raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype}
is not supported")
+ raise NotImplementedError(
+ f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not
supported"
+ )
if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] +
bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
+def schedule_matmul_blas_common(outs):
+ """Default matmul schedule for BLAS library"""
+ s = te.create_schedule([x.op for x in outs])
+ te.schedule.AutoInlineInjective(s)
+
+ for out in outs:
+ if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+ schedule_injective_from_existing(s, out)
+ return s
+
+
@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
- """Compute dense using a cblas"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
+ """Compute dense using cblas. This is an alias of matmul_nt operator."""
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
cblas)
@autotvm.register_topi_schedule("dense_cblas.x86")
def schedule_dense_cblas(_, outs):
- """Create schedule for dense_cblas"""
- return generic.schedule_extern(outs)
+ """Create schedule for dense_cblas. This is an alias of matmul_nt
operator."""
+ return schedule_matmul_blas_common(outs)
@autotvm.register_topi_compute("dense_mkl.x86")
def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
- """Compute dense using mkl"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl)
+ """Compute dense using mkl. This is an alias of matmul_nt operator."""
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
mkl)
@autotvm.register_topi_schedule("dense_mkl.x86")
def schedule_dense_mkl(_, outs):
- """Create schedule for dense_mkl"""
- # return generic.schedule_extern(outs)
- s = te.create_schedule([x.op for x in outs])
- te.schedule.AutoInlineInjective(s)
-
- def _callback(op):
- if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in
op.tag:
- schedule_injective_from_existing(s, op.output(0))
-
- # traverse_inline(s, outs[0].op, _callback)
- for out in outs:
- if "dense" not in out.op.name:
- schedule_injective_from_existing(s, out)
- return s
+ """Create schedule for dense_mkl. This is an alias of matmul_nt
operator."""
+ return schedule_matmul_blas_common(outs)
@autotvm.register_topi_compute("dense_mkldnn.x86")
def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None):
- """Compute dense using mkldnn"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, mkldnn)
+ """Compute dense using mkldnn. This is an alias of matmul_nt operator."""
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkldnn)
@autotvm.register_topi_schedule("dense_mkldnn.x86")
def schedule_dense_mkldnn(_, outs):
- """Create schedule for dense_mkldnn"""
- return generic.schedule_extern(outs)
+ """Create schedule for dense_mkldnn. This is an alias of matmul_nt operator."""
+ return schedule_matmul_blas_common(outs)
+
+
[email protected]_topi_compute("matmul_cblas.x86")
+def matmul_cblas(
+ cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False
+):
+ """Compute matmul using cblas."""
+ return matmul_blas_common(
+ cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, cblas
+ )
+
+
[email protected]_topi_schedule("matmul_cblas.x86")
+def schedule_matmul_cblas(_, outs):
+ """Create schedule for matmul_cblas."""
+ return schedule_matmul_blas_common(outs)
+
+
[email protected]_topi_compute("matmul_mkl.x86")
+def matmul_mkl(
+ cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False
+):
+ """Compute matmul using mkl."""
+ return matmul_blas_common(
+ cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, mkl
+ )
+
+
[email protected]_topi_schedule("matmul_mkl.x86")
+def schedule_matmul_mkl(_, outs):
+ """Create schedule for matmul_mkl."""
+ return schedule_matmul_blas_common(outs)
+
+
[email protected]_topi_compute("matmul_mkldnn.x86")
+def matmul_mkldnn(
+ cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False
+):
+ """Compute matmul using mkldnn."""
+ return matmul_blas_common(
+ cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, mkldnn
+ )
+
+
[email protected]_topi_schedule("matmul_mkldnn.x86")
+def schedule_matmul_mkldnn(_, outs):
+ """Create schedule for matmul_mkldnn."""
+ return schedule_matmul_blas_common(outs)
diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs
index f0137fa..04320d1 100644
--- a/rust/tvm/src/ir/relay/attrs/nn.rs
+++ b/rust/tvm/src/ir/relay/attrs/nn.rs
@@ -56,6 +56,18 @@ pub struct BiasAddAttrsNode {
#[repr(C)]
#[derive(Object, Debug)]
+#[ref_name = "MatmulAttrs"]
+#[type_key = "relay.attrs.MatmulAttrs"]
+pub struct MatmulAttrsNode {
+ pub base: BaseAttrsNode,
+ pub units: IndexExpr,
+ pub out_dtype: DataType,
+ pub transpose_a: bool,
+ pub transpose_b: bool,
+}
+
+#[repr(C)]
+#[derive(Object, Debug)]
#[ref_name = "DenseAttrs"]
#[type_key = "relay.attrs.DenseAttrs"]
pub struct DenseAttrsNode {
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 81de4bc..6f4db5a 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -44,6 +44,9 @@ Expr MakeClip(Expr a, double a_min, double a_max);
Expr MakeConcatenate(Expr data, int axis);
+Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtype, bool transpose_a,
+ bool transpose_b);
+
Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);
Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype);
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 489be15..4eaa12b 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -162,7 +162,39 @@ Useful for
.set_support_level(3)
.add_type_rel("FIFOBuffer", FIFOBufferRel);
-// relay.nn.dense
+// ------------------- relay.nn.matmul
+TVM_REGISTER_NODE_TYPE(MatmulAttrs);
+
+Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtype, bool transpose_a,
+ bool transpose_b) {
+ auto attrs = make_object<MatmulAttrs>();
+ attrs->units = units;
+ attrs->out_dtype = out_dtype;
+ attrs->transpose_a = transpose_a;
+ attrs->transpose_b = transpose_b;
+ static const Op& matmul_op = Op::Get("nn.matmul");
+ return Call(matmul_op, {tensor_a, tensor_b}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.matmul").set_body_typed(MakeMatmul);
+
+RELAY_REGISTER_OP("nn.matmul")
+ .describe(R"code(Applies a linear transformation: :math:`C = A * B`. A & B can be transposed.
+
+- **tensor_a**: `(x1, x2, ..., xn, input_dim)` or `(x1, x2, ..., input_dim, xn)`
+- **tensor_b**: `(input_dim, units)` or `(units, input_dim)`
+- **out**: `(x1, x2, ..., xn, units)`.
+
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<MatmulAttrs>()
+ .set_num_inputs(2)
+ .add_argument("tensor_a", "nD Tensor", "The first input Tensor.")
+ .add_argument("tensor_b", "2D Tensor", "The second input Tensor.")
+ .set_support_level(1)
+ .add_type_rel("Matmul", MatmulRel<MatmulAttrs>);
+// ------------------- relay.nn.matmul
+
+// ------------------- relay.nn.dense
TVM_REGISTER_NODE_TYPE(DenseAttrs);
// Positional relay function to create dense operator used by frontend FFI.
@@ -189,9 +221,10 @@ RELAY_REGISTER_OP("nn.dense")
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
.set_support_level(1)
- .add_type_rel("Dense", DenseRel<DenseAttrs>);
+ .add_type_rel("Dense", MatmulRel<DenseAttrs>);
+// ------------------- relay.nn.dense
-// relay.nn.contrib_dense_pack
+// ------------------- relay.nn.contrib_dense_pack
// Positional relay function to create dense_pack operator used by frontend FFI.
Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
auto attrs = make_object<DenseAttrs>();
@@ -217,6 +250,7 @@ RELAY_REGISTER_OP("nn.contrib_dense_pack")
.add_argument("weight", "3D Tensor", "Packed weight matrix.")
.set_support_level(10)
.add_type_rel("DensePack", DensePackRel<DenseAttrs>);
+// ------------------- relay.nn.contrib_dense_pack
// relay.leaky_relu
TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 1ac800f..29f200c 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -36,31 +36,44 @@ namespace tvm {
namespace relay {
template <typename AttrType>
-bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
- const TypeReporter& reporter) {
+bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 3);
- const auto* data = types[0].as<TensorTypeNode>();
- const auto* weight = types[1].as<TensorTypeNode>();
- if (data == nullptr) return false;
+ const auto* tensor_a = types[0].as<TensorTypeNode>();
+ const auto* tensor_b = types[1].as<TensorTypeNode>();
+ if (tensor_a == nullptr) return false;
+ ICHECK(static_cast<int>(tensor_a->shape.size()) != 0);
const AttrType* param = attrs.as<AttrType>();
ICHECK(param != nullptr);
+ // Default set to dense layout
+ bool transpose_a = false;
+ bool transpose_b = true;
+ const auto& mattrs = attrs.as<MatmulAttrs>();
+ if (mattrs != nullptr) {
+ transpose_a = mattrs->transpose_a;
+ transpose_b = mattrs->transpose_b;
+ }
- ICHECK(static_cast<int>(data->shape.size()) != 0);
-
- Array<tvm::PrimExpr> dshape = data->shape;
+ const Array<tvm::PrimExpr>& dshape = tensor_a->shape;
Array<tvm::PrimExpr> oshape = dshape;
+ tvm::PrimExpr reduce = dshape[dshape.size() - 1];
+ if (transpose_a) {
+ reduce = dshape[dshape.size() - 2];
+ oshape.Set((oshape.size() - 2), dshape[oshape.size() - 1]);
+ }
if (param->units.defined()) {
- // validate the weight shape is proper if defined
- // Assign weight type
- Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
- // It is possible for weight to be nullptr in which case we will use
- // data dtype as the weight dtype. However if weight dtype is explicitly
+ // validate the tensor_b shape is proper if defined
+ // Assign tensor_b type
+ const Array<IndexExpr>& wshape = transpose_b ? Array<IndexExpr>({param->units, reduce})
+ : Array<IndexExpr>({reduce, param->units});
+ // It is possible for tensor_b to be nullptr in which case we will use
+ // data dtype as the tensor_b dtype. However if tensor_b dtype is explicitly
// present we will use that.
- auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
+ auto tensor_b_dtype = (tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
if (param->auto_scheduler_rewritten_layout.size() == 0) {
// Normal case: assign result to reporter
- reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+ reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
} else {
// If the layout is rewritten by auto-scheduler,
// we just forcly apply the layout provided by auto-scheduler and
@@ -69,31 +82,32 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
oshape.Set((oshape.size() - 1), param->units);
} else {
- if (weight == nullptr) return false;
- Array<tvm::PrimExpr> wshape = weight->shape;
- // When weight's layout has been rewritten, figure it out based on the
+ if (tensor_b == nullptr) return false;
+ const Array<tvm::PrimExpr>& wshape = tensor_b->shape;
+ // When tensor_b's layout has been rewritten, figure it out based on the
// total number of elements and input dimensions.
if (param->auto_scheduler_rewritten_layout.size() != 0) {
- PrimExpr weight_elements = 1;
+ PrimExpr tensor_b_elements = 1;
for (size_t i = 0; i < wshape.size(); i++) {
- weight_elements = weight_elements * wshape[i];
+ tensor_b_elements = tensor_b_elements * wshape[i];
}
- oshape.Set(oshape.size() - 1, weight_elements / dshape[dshape.size() - 1]);
- // Otherwise just pull it out of the weight shape directly.
+ oshape.Set(oshape.size() - 1, tensor_b_elements / dshape[dshape.size() - 1]);
+ // Otherwise just pull it out of the tensor_b shape directly.
} else {
- ICHECK(static_cast<int>(weight->shape.size()) == 2);
- if (!data->shape.back().as<tir::AnyNode>()) {
- ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
- << "DenseRel: input dimension doesn't match,"
- << " data shape=" << data->shape << ", weight shape=" << weight->shape;
+ ICHECK(static_cast<int>(tensor_b->shape.size()) == 2);
+ if (!tensor_a->shape.back().as<tir::AnyNode>()) {
+ ICHECK((transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[1])) ||
+ (!transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[0])))
+ << "MatmulRel: input dimension doesn't match,"
+ << " tensor_a shape=" << tensor_a->shape << ", tensor_b shape=" << tensor_b->shape;
}
- oshape.Set((oshape.size() - 1), wshape[0]);
+ oshape.Set((oshape.size() - 1), transpose_b ? wshape[0] : wshape[1]);
}
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
- out_dtype = data->dtype;
+ out_dtype = tensor_a->dtype;
}
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index 6284524..592fa77 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -70,7 +70,7 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// Dense infer type function.
Array<Type> tensor_types = {types[0], types[1], types[6]};
- return DenseRel<DenseAttrs>(tensor_types, 3, attrs, reporter);
+ return MatmulRel<DenseAttrs>(tensor_types, 3, attrs, reporter);
}
// Positional relay function to create quantized dense operator used by frontend FFI.
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
index edc4119..da0bd35 100644
--- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -87,6 +87,8 @@ class FuncMutator : public ExprMutator {
updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
} else if (auto pattr = call->attrs.as<Conv3DAttrs>()) {
updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+ } else if (auto pattr = call->attrs.as<MatmulAttrs>()) {
+ updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
} else if (auto pattr = call->attrs.as<DenseAttrs>()) {
updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
} else if (auto pattr = call->attrs.as<BatchMatmulAttrs>()) {
@@ -103,9 +105,9 @@ class FuncMutator : public ExprMutator {
std::deque<std::string> ori_layouts_queue_;
std::deque<std::string> new_layouts_queue_;
- std::vector<std::string> target_ops_{"nn.conv2d", "nn.conv3d",
- "nn.contrib_conv2d_winograd_without_weight_transform",
- "nn.dense", "nn.batch_matmul"};
+ std::vector<std::string> target_ops_{
+ "nn.conv2d", "nn.conv3d", "nn.contrib_conv2d_winograd_without_weight_transform",
+ "nn.matmul", "nn.dense", "nn.batch_matmul"};
};
Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
@@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
return attrs.as<Conv2DWinogradAttrs>()->auto_scheduler_rewritten_layout;
} else if (attrs->IsInstance<Conv3DAttrs>()) {
return attrs.as<Conv3DAttrs>()->auto_scheduler_rewritten_layout;
+ } else if (attrs->IsInstance<MatmulAttrs>()) {
+ return attrs.as<MatmulAttrs>()->auto_scheduler_rewritten_layout;
} else if (attrs->IsInstance<DenseAttrs>()) {
return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout;
} else if (attrs->IsInstance<BatchMatmulAttrs>()) {
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 136dcab..583014f 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -117,6 +117,7 @@ def run_tvm_graph(
disabled_pass=None,
ignore_in_shape=False,
serialize=False,
+ use_dense_op=True,
):
"""Generic function to compile on relay and execute on tvm"""
input_data = convert_to_list(input_data)
@@ -131,7 +132,11 @@ def run_tvm_graph(
e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)
}
mod, params = relay.frontend.from_tensorflow(
- graph_def, layout=layout, shape=shape_dict, outputs=out_names
+ graph_def,
+ layout=layout,
+ shape=shape_dict,
+ outputs=out_names,
+ use_dense_op=use_dense_op,
)
dev = tvm.device(target, 0)
if mode == "debug":
@@ -213,6 +218,7 @@ def compare_tf_with_tvm(
add_shapes_to_graph_def=True,
targets=None,
ignore_in_shape=False,
+ use_dense_op=True,
):
"""Generic function to generate and compare tensorflow and TVM output"""
@@ -260,6 +266,7 @@ def compare_tf_with_tvm(
mode=mode,
cuda_layout=cuda_layout,
ignore_in_shape=ignore_in_shape,
+ use_dense_op=use_dense_op,
)
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
@@ -1795,7 +1802,8 @@ def _test_matmul(i, j, k, dtype, outer=None):
A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
- compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
+ compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=True)
+ compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=False)
def test_forward_matmul():
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index 686fd98..c8a9468 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -199,6 +199,22 @@ def test_dense_grad():
verify_dense_grad((5, 4), (3, 4))
+def verify_matmul_grad(a_shape, b_shape, transpose_a, transpose_b):
+ tensor_a = relay.var("tensor_a", relay.TensorType(a_shape, "float32"))
+ tensor_b = relay.var("tensor_b", relay.TensorType(b_shape, "float32"))
+ fwd_func = relay.Function(
+ [tensor_a, tensor_b],
+ relay.nn.matmul(tensor_a, tensor_b, transpose_a=transpose_a, transpose_b=transpose_b),
+ )
+ check_grad(fwd_func)
+
+
+def test_matmul_grad():
+ verify_matmul_grad((1, 8), (8, 16), False, False)
+ verify_matmul_grad((4, 1), (4, 3), True, False)
+ verify_matmul_grad((4, 5), (3, 4), True, True)
+
+
def verify_batch_flatten_grad(d_shape):
data = relay.var("data", relay.TensorType(d_shape, "float32"))
fwd_func = relay.Function([data], relay.nn.batch_flatten(data))
@@ -216,4 +232,5 @@ if __name__ == "__main__":
test_global_avg_pool2d_grad()
test_conv2d_grad()
test_dense_grad()
+ test_matmul_grad()
test_batch_flatten_grad()
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 89475ac..cbc3e7f 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -410,6 +410,66 @@ def test_batch_norm():
@pytest.mark.xfail
+def test_matmul_type_check():
+ dtype = "float16"
+ n, c, h, w = 2, 2, 2, 2
+ x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+ # it should fail since it does not match with m(2)
+ mismatch_w = 3
+ w = relay.var("w", relay.TensorType((mismatch_w, 2), dtype))
+ y = relay.nn.matmul(x, w)
+ yy = run_infer_type(y)
+
+
[email protected]_gpu
+def test_matmul():
+ for dtype in ["float16", "float32"]:
+ # Matmul accuracy for float16 is poor
+ if dtype == "float16":
+ continue
+ n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+ x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+ w = relay.var("w", relay.TensorType((2, w), dtype))
+ y = relay.nn.matmul(x, w, units=2, transpose_b=True)
+ assert "units=2" in y.astext()
+ yy = run_infer_type(y)
+ assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
+
+ n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
+ x = relay.var("x", relay.TensorType((n, c, w, h), dtype))
+ wh, ww = te.size_var("wh"), te.size_var("ww")
+ w = relay.var("w", relay.TensorType((wh, ww), dtype))
+ y = relay.nn.matmul(x, w, transpose_a=True)
+ yy = run_infer_type(y)
+ assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype)
+
+ n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
+ x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+ w = relay.var("w", relay.IncompleteType())
+ y = relay.nn.matmul(x, w, units=2)
+ yy = run_infer_type(y)
+ assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype)
+
+ x = relay.var("x", shape=(5, 10), dtype=dtype)
+ w = relay.var("w", shape=(5, 2), dtype=dtype)
+ z = relay.nn.matmul(x, w, transpose_a=True)
+
+ # Check result.
+ func = relay.Function([x, w], z)
+ x_data = np.random.rand(5, 10).astype(dtype)
+ w_data = np.random.rand(5, 2).astype(dtype)
+ ref_res = np.dot(x_data.transpose(), w_data)
+
+ for target, dev in tvm.testing.enabled_targets():
+ intrp1 = relay.create_executor("graph", device=dev, target=target)
+ intrp2 = relay.create_executor("debug", device=dev, target=target)
+ op_res1 = intrp1.evaluate(func)(x_data, w_data)
+ tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5)
+ op_res2 = intrp2.evaluate(func)(x_data, w_data)
+ tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5)
+
+
[email protected]
def test_dense_type_check():
dtype = "float16"
n, c, h, w = 2, 2, 2, 2
@@ -426,7 +486,7 @@ def test_dense():
for dtype in ["float16", "float32"]:
# Dense accuracy for float16 is poor
if dtype == "float16":
- return
+ continue
n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
w = relay.var("w", relay.TensorType((2, w), dtype))
@@ -506,6 +566,7 @@ if __name__ == "__main__":
test_log_softmax()
test_dropout()
test_batch_norm()
+ test_matmul()
test_dense()
test_bitserial_dense()
test_dense_dtype()
diff --git a/tests/python/topi/python/test_topi_matmul.py b/tests/python/topi/python/test_topi_matmul.py
index e5a21a3..de2d4d3 100644
--- a/tests/python/topi/python/test_topi_matmul.py
+++ b/tests/python/topi/python/test_topi_matmul.py
@@ -41,6 +41,31 @@ def with_tvm(lam, *args):
return out_nd.numpy()
+def verify_nn_matmul(sa, sb, transp_a, transp_b):
+ a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
+ b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
+ c1 = np.matmul(np.transpose(a) if transp_a else a, np.transpose(b) if transp_b else b)
+ c2 = with_tvm(
+ lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b),
+ a,
+ b,
+ )
+ tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
+
+
+def test_nn_matmul():
+ verify_nn_matmul((1, 1), (1, 1), False, False)
+ verify_nn_matmul((1, 1), (1, 1), True, True)
+ verify_nn_matmul((2, 2), (2, 2), False, False)
+ verify_nn_matmul((2, 2), (2, 2), True, True)
+ verify_nn_matmul((2, 3), (3, 5), False, False)
+ verify_nn_matmul((5, 3), (3, 2), False, False)
+ verify_nn_matmul((3, 5), (3, 2), True, False)
+ verify_nn_matmul((3, 5), (2, 3), True, True)
+ verify_nn_matmul((3, 5), (3, 2), True, False)
+ verify_nn_matmul((5, 3), (2, 3), False, True)
+
+
def verify_matmul(sa, sb, transp_a, transp_b):
a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
@@ -79,5 +104,6 @@ def test_tensordot():
if __name__ == "__main__":
+ test_nn_matmul()
test_matmul()
test_tensordot()