comaniac commented on a change in pull request #8234:
URL: https://github.com/apache/tvm/pull/8234#discussion_r659979731
##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -1160,21 +1186,46 @@ def batch_flatten_shape_func(attrs, inputs, _):
@script
-def _dense_shape_func(data_shape, weight_shape):
+def _matmul_shape_func(data_shape, weight_shape, data_transposed,
weight_transposed):
Review comment:
nit: the two inputs of matmul are not necessarily `data` and
`weight`, although that naming is almost always right in DNNs. We may consider
calling them `transpose_a` and `transpose_b`, as cuDNN does.
##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -44,6 +44,10 @@
__all__ = ["from_tensorflow"]
+# By default, TVM convert `tf.matmul` to `nn.dense` op with data tensor
non-transposed and weight
+# tensor transposed
+_USE_DENSE_INSTEAD_OF_MATMUL = True
Review comment:
```suggestion
# The default configurations of Relay TensorFlow frontend.
TF_DEFAULT_CONFIGS = {
# By default, TVM converts `tf.matmul` to `transpose(weight) + nn.dense`,
which introduces
# unnecessary overhead in weight transpose. Change this flag to False to
directly convert to
# `nn.matmul` to get rid of the overhead. However, please note that
`nn.matmul` is still experimental,
# so it may have some performance issues.
"use_dense": True,
}
```
I reviewed the primary entrypoint and feel we may need to make it more
general, as we may have other configurations in the future. Also I used the
name "default" to imply that it may be overwritten by user-specified values.
##########
File path: python/tvm/topi/x86/dense.py
##########
@@ -281,72 +281,121 @@ def _callback(op):
return s
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
- """Compute dense using a BLAS library"""
+def matmul_blas_common(cfg, data, weight, bias, out_dtype, data_transposed,
weight_transposed, lib):
+ """Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype ==
"int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
- f"Dense with {lib.__name__} for {data.dtype} is not supported "
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported "
"(matmulu8s8s32 not imlemented)"
)
- C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
+ C = lib.matmul_u8s8s32(data, weight, data_transposed,
weight_transposed, dtype=out_dtype)
elif data.dtype == "float32" or data.dtype == "float64":
- C = lib.matmul(data, weight, False, True)
+ C = lib.matmul(data, weight, data_transposed, weight_transposed)
else:
- raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype}
is not supported")
+ raise NotImplementedError(
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported"
+ )
if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] +
bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
+def schedule_matmul_blas_common(outs):
+ """Default matmul schedule for BLAS library"""
+ s = te.create_schedule([x.op for x in outs])
+ te.schedule.AutoInlineInjective(s)
+
+ for out in outs:
+ if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+ schedule_injective_from_existing(s, out)
+ return s
+
+
@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using a cblas"""
Review comment:
ditto
##########
File path: python/tvm/topi/x86/dense.py
##########
@@ -281,72 +281,121 @@ def _callback(op):
return s
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
- """Compute dense using a BLAS library"""
+def matmul_blas_common(cfg, data, weight, bias, out_dtype, data_transposed,
weight_transposed, lib):
+ """Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype ==
"int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
- f"Dense with {lib.__name__} for {data.dtype} is not supported "
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported "
"(matmulu8s8s32 not imlemented)"
)
- C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
+ C = lib.matmul_u8s8s32(data, weight, data_transposed,
weight_transposed, dtype=out_dtype)
elif data.dtype == "float32" or data.dtype == "float64":
- C = lib.matmul(data, weight, False, True)
+ C = lib.matmul(data, weight, data_transposed, weight_transposed)
else:
- raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype}
is not supported")
+ raise NotImplementedError(
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported"
+ )
if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] +
bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
+def schedule_matmul_blas_common(outs):
+ """Default matmul schedule for BLAS library"""
+ s = te.create_schedule([x.op for x in outs])
+ te.schedule.AutoInlineInjective(s)
+
+ for out in outs:
+ if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+ schedule_injective_from_existing(s, out)
+ return s
+
+
@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using a cblas"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
cblas)
@autotvm.register_topi_schedule("dense_cblas.x86")
def schedule_dense_cblas(_, outs):
"""Create schedule for dense_cblas"""
- return generic.schedule_extern(outs)
+ return schedule_matmul_blas_common(outs)
@autotvm.register_topi_compute("dense_mkl.x86")
def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using mkl"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
mkl)
@autotvm.register_topi_schedule("dense_mkl.x86")
def schedule_dense_mkl(_, outs):
"""Create schedule for dense_mkl"""
- # return generic.schedule_extern(outs)
- s = te.create_schedule([x.op for x in outs])
- te.schedule.AutoInlineInjective(s)
-
- def _callback(op):
- if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in
op.tag:
- schedule_injective_from_existing(s, op.output(0))
-
- # traverse_inline(s, outs[0].op, _callback)
- for out in outs:
- if "dense" not in out.op.name:
- schedule_injective_from_existing(s, out)
- return s
+ return schedule_matmul_blas_common(outs)
@autotvm.register_topi_compute("dense_mkldnn.x86")
def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using mkldnn"""
Review comment:
ditto
##########
File path: python/tvm/topi/x86/dense.py
##########
@@ -281,72 +281,121 @@ def _callback(op):
return s
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
- """Compute dense using a BLAS library"""
+def matmul_blas_common(cfg, data, weight, bias, out_dtype, data_transposed,
weight_transposed, lib):
+ """Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype ==
"int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
- f"Dense with {lib.__name__} for {data.dtype} is not supported "
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported "
"(matmulu8s8s32 not imlemented)"
)
- C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
+ C = lib.matmul_u8s8s32(data, weight, data_transposed,
weight_transposed, dtype=out_dtype)
elif data.dtype == "float32" or data.dtype == "float64":
- C = lib.matmul(data, weight, False, True)
+ C = lib.matmul(data, weight, data_transposed, weight_transposed)
else:
- raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype}
is not supported")
+ raise NotImplementedError(
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported"
+ )
if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] +
bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
+def schedule_matmul_blas_common(outs):
+ """Default matmul schedule for BLAS library"""
+ s = te.create_schedule([x.op for x in outs])
+ te.schedule.AutoInlineInjective(s)
+
+ for out in outs:
+ if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+ schedule_injective_from_existing(s, out)
+ return s
+
+
@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using a cblas"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
cblas)
@autotvm.register_topi_schedule("dense_cblas.x86")
def schedule_dense_cblas(_, outs):
"""Create schedule for dense_cblas"""
- return generic.schedule_extern(outs)
+ return schedule_matmul_blas_common(outs)
@autotvm.register_topi_compute("dense_mkl.x86")
def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using mkl"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
mkl)
@autotvm.register_topi_schedule("dense_mkl.x86")
def schedule_dense_mkl(_, outs):
"""Create schedule for dense_mkl"""
- # return generic.schedule_extern(outs)
- s = te.create_schedule([x.op for x in outs])
- te.schedule.AutoInlineInjective(s)
-
- def _callback(op):
- if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in
op.tag:
- schedule_injective_from_existing(s, op.output(0))
-
- # traverse_inline(s, outs[0].op, _callback)
- for out in outs:
- if "dense" not in out.op.name:
- schedule_injective_from_existing(s, out)
- return s
+ return schedule_matmul_blas_common(outs)
@autotvm.register_topi_compute("dense_mkldnn.x86")
def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using mkldnn"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, mkldnn)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
mkldnn)
@autotvm.register_topi_schedule("dense_mkldnn.x86")
def schedule_dense_mkldnn(_, outs):
"""Create schedule for dense_mkldnn"""
Review comment:
ditto
##########
File path: python/tvm/topi/nn/dense.py
##########
@@ -51,37 +65,120 @@ def dense(data, weight, bias=None, out_dtype=None,
auto_scheduler_rewritten_layo
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
- batch, in_dim = data.shape
+ if data_transposed:
+ in_dim, batch = data.shape
+ else:
+ batch, in_dim = data.shape
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
- auto_scheduler_rewritten_layout, ["j", "k"]
+ auto_scheduler_rewritten_layout, ["j", "k"] if weight_transposed
else ["k", "j"]
)
auto_scheduler.remove_index_check(weight)
- else:
+ elif weight_transposed:
out_dim, red_dim = weight.shape
+ else:
+ red_dim, out_dim = weight.shape
assert in_dim == red_dim
k = te.reduce_axis((0, in_dim), name="k")
- matmul = te.compute(
+ if data_transposed:
+ if weight_transposed:
+ compute_lambda = lambda i, j: te.sum(
+ data[k, i].astype(out_dtype) * weight[j, k].astype(out_dtype),
axis=k
+ )
+ compute_name = "T_matmul_TT"
+ else:
+ compute_lambda = lambda i, j: te.sum(
+ data[k, i].astype(out_dtype) * weight[k, j].astype(out_dtype),
axis=k
+ )
+ compute_name = "T_matmul_TN"
+ compute_tag = "matmul"
+ else:
+ if weight_transposed:
+ compute_lambda = lambda i, j: te.sum(
+ data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype),
axis=k
+ )
+ compute_name = "T_dense"
+ compute_tag = "dense"
+ else:
+ compute_lambda = lambda i, j: te.sum(
+ data[i, k].astype(out_dtype) * weight[k, j].astype(out_dtype),
axis=k
+ )
+ compute_name = "T_matmul"
+ compute_tag = "matmul"
+
+ mat = te.compute(
(batch, out_dim),
- lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j,
k].astype(out_dtype), axis=k),
- name="T_dense",
- tag="dense",
+ compute_lambda,
+ name=compute_name,
+ tag=compute_tag,
attrs={"layout_free_placeholders": [weight]},
)
+
if bias is not None:
- matmul = te.compute(
+ mat = te.compute(
(batch, out_dim),
- lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
+ lambda i, j: mat[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST,
)
if auto_scheduler_rewritten_layout:
- matmul = auto_scheduler.rewrite_compute_body(matmul,
auto_scheduler_rewritten_layout)
+ mat = auto_scheduler.rewrite_compute_body(mat,
auto_scheduler_rewritten_layout)
+
+ return mat
+
- return matmul
[email protected]_func
+def matmul_legalize(attrs, inputs, types):
+ """Legalizes matmul op.
+
+ Parameters
+ ----------
+ attrs : tvm.ir.Attrs
+ Attributes of current matmul
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized
+ types : list of types
+ List of input and output types
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The legalized expr
+ """
+ # not to change by default
+ # pylint: disable=unused-argument
+ return None
+
+
+def dense(data, weight, bias=None, out_dtype=None,
auto_scheduler_rewritten_layout=""):
+ """The default implementation of dense in topi.
Review comment:
```suggestion
"""The default implementation of dense in topi. It is identical to
matmul_nt.
```
##########
File path: python/tvm/topi/cuda/dense.py
##########
@@ -77,6 +111,36 @@ def _callback(op):
return s
[email protected]_topi_compute("matmul_default.cuda")
+def matmul_default_cuda(
+ cfg,
+ data,
+ weight,
+ bias=None,
+ out_dtype=None,
+ data_transposed=False,
+ weight_transposed=False,
+):
+ """Matmul operator on cuda"""
Review comment:
ditto
##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -698,6 +698,36 @@ def conv1d_transpose_strategy_cuda(attrs, inputs,
out_type, target):
return strategy
+@matmul_strategy.register(["cuda", "gpu"])
+def matmul_strategy_cuda(attrs, inputs, out_type, target):
+ """Matmul cuda strategy"""
+ strategy = _op.OpStrategy()
+
+ if is_auto_scheduler_enabled():
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.nn.matmul),
+ naive_schedule,
+ name="matmul.cuda",
+ )
+ else:
+ logger.warning("Matmul other than NT format is not optimized for
cuda.")
Review comment:
- People may not know what "NT" is, as we didn't use this term in the matmul
op.
- We may recommend using cublas in this case.
##########
File path: python/tvm/relay/op/strategy/x86.py
##########
@@ -370,6 +370,55 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):
return strategy
+@matmul_strategy.register("cpu")
+def matmul_strategy_cpu(attrs, inputs, out_type, target):
+ """matmul x86 strategy"""
+ strategy = _op.OpStrategy()
+ if is_auto_scheduler_enabled():
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.nn.matmul,
need_auto_scheduler_layout=True),
+ naive_schedule,
+ name="matmul.generic",
+ plevel=11,
+ )
+ else:
+ logger.warning("Matmul other than NT format is not optimized for x86.")
Review comment:
ditto.
##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1230,6 +1239,9 @@ def from_tensorflow(graph, layout="NHWC", shape=None,
outputs=None):
params : dict of str to tvm.nd.NDArray
Dict of converted parameters stored in tvm.nd.NDArray format
"""
+ global _USE_DENSE_INSTEAD_OF_MATMUL
Review comment:
I feel this is confusing too. If `_USE_DENSE_INSTEAD_OF_MATMUL` is not
supposed to be changed by users directly, we should improve the comments of
this global variable. Please see my comment at the global variable.
btw, in this case we can simply `_USE_DENSE_INSTEAD_OF_MATMUL =
use_dense_op` without checking if they are the same or not.
##########
File path: python/tvm/topi/x86/dense.py
##########
@@ -281,72 +281,121 @@ def _callback(op):
return s
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
- """Compute dense using a BLAS library"""
+def matmul_blas_common(cfg, data, weight, bias, out_dtype, data_transposed,
weight_transposed, lib):
+ """Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype ==
"int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
- f"Dense with {lib.__name__} for {data.dtype} is not supported "
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported "
"(matmulu8s8s32 not imlemented)"
)
- C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
+ C = lib.matmul_u8s8s32(data, weight, data_transposed,
weight_transposed, dtype=out_dtype)
elif data.dtype == "float32" or data.dtype == "float64":
- C = lib.matmul(data, weight, False, True)
+ C = lib.matmul(data, weight, data_transposed, weight_transposed)
else:
- raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype}
is not supported")
+ raise NotImplementedError(
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported"
+ )
if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] +
bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
+def schedule_matmul_blas_common(outs):
+ """Default matmul schedule for BLAS library"""
+ s = te.create_schedule([x.op for x in outs])
+ te.schedule.AutoInlineInjective(s)
+
+ for out in outs:
+ if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+ schedule_injective_from_existing(s, out)
+ return s
+
+
@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using a cblas"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
cblas)
@autotvm.register_topi_schedule("dense_cblas.x86")
def schedule_dense_cblas(_, outs):
"""Create schedule for dense_cblas"""
- return generic.schedule_extern(outs)
+ return schedule_matmul_blas_common(outs)
@autotvm.register_topi_compute("dense_mkl.x86")
def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using mkl"""
Review comment:
ditto
##########
File path: python/tvm/topi/cuda/dense.py
##########
@@ -51,6 +57,34 @@ def dense_cublas(cfg, data, weight, bias=None,
out_dtype=None):
return matmul
[email protected]_topi_compute("matmul_cublas.cuda")
+def matmul_cublas(
+ cfg,
+ data,
+ weight,
+ bias=None,
+ out_dtype=None,
+ data_transposed=False,
+ weight_transposed=False,
+):
+ """Matmul operator on CUDA with CUBLAS"""
+ return _matmul_cublas_common(
+ cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed
+ )
+
+
[email protected]_topi_schedule("matmul_cublas.cuda")
+def schedule_matmul_cublas(_, outs):
+ """Schedule matmul operator using CUBLAS"""
+ return generic.schedule_extern(outs)
+
+
[email protected]_topi_compute("dense_cublas.cuda")
+def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
+ """Dense operator on CUDA with CUBLAS"""
Review comment:
```suggestion
"""Dense operator on CUDA with CUBLAS. This is an alias of matmul_nt
operator"""
```
Please add a note to all other similar places saying the dense
implementation is identical to matmul_nt.
##########
File path: python/tvm/topi/x86/dense.py
##########
@@ -281,72 +281,121 @@ def _callback(op):
return s
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
- """Compute dense using a BLAS library"""
+def matmul_blas_common(cfg, data, weight, bias, out_dtype, data_transposed,
weight_transposed, lib):
+ """Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype ==
"int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
- f"Dense with {lib.__name__} for {data.dtype} is not supported "
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported "
"(matmulu8s8s32 not imlemented)"
)
- C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
+ C = lib.matmul_u8s8s32(data, weight, data_transposed,
weight_transposed, dtype=out_dtype)
elif data.dtype == "float32" or data.dtype == "float64":
- C = lib.matmul(data, weight, False, True)
+ C = lib.matmul(data, weight, data_transposed, weight_transposed)
else:
- raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype}
is not supported")
+ raise NotImplementedError(
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported"
+ )
if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] +
bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
+def schedule_matmul_blas_common(outs):
+ """Default matmul schedule for BLAS library"""
+ s = te.create_schedule([x.op for x in outs])
+ te.schedule.AutoInlineInjective(s)
+
+ for out in outs:
+ if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+ schedule_injective_from_existing(s, out)
+ return s
+
+
@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using a cblas"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
cblas)
@autotvm.register_topi_schedule("dense_cblas.x86")
def schedule_dense_cblas(_, outs):
"""Create schedule for dense_cblas"""
Review comment:
ditto
##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1471,6 +1471,47 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)
+def matmul(data, weight, units=None, out_dtype="", data_transposed=False,
weight_transposed=False):
+ """Matmul operator.
+ Applies a linear transformation. The X & W can be transposed.
+
+ .. math::
+
+ `Y = X * W`
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input data to the operator,
+ of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ...,
units_in, d_n)`.
+
+ weight : tvm.relay.Expr
+ The weight expressions, 2-D matrix,
+ of shape `(units_in, units)` or `(units, units_in)`.
+
+ units : int, optional
+ Number of hidden units of the matmul transformation.
+
+ out_dtype : str, optional
+ Specifies the output data type for mixed precision dense,
+ of shape `(d_1, d_2, ..., d_n, units)`.
+
+ data_transposed : bool, optional
+ Whether the data tensor is in transposed format.
+
+ weight_transposed : bool, optional
+ Whether the weight tensor is in transposed format.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+ if not data_transposed and weight_transposed:
+ return dense(data, weight, units, out_dtype)
Review comment:
Add a TODO saying that this will be removed once dense has been
simplified?
##########
File path: python/tvm/relay/op/strategy/x86.py
##########
@@ -370,6 +370,55 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):
return strategy
+@matmul_strategy.register("cpu")
+def matmul_strategy_cpu(attrs, inputs, out_type, target):
+ """matmul x86 strategy"""
+ strategy = _op.OpStrategy()
+ if is_auto_scheduler_enabled():
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.nn.matmul,
need_auto_scheduler_layout=True),
+ naive_schedule,
+ name="matmul.generic",
+ plevel=11,
+ )
+ else:
+ logger.warning("Matmul other than NT format is not optimized for x86.")
+ strategy.add_implementation(
+ wrap_compute_matmul(topi.nn.matmul),
+ naive_schedule,
+ name="matmul.generic",
+ )
+
+ same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
+ dtype = inputs[0].dtype
+ u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and
out_type.dtype == "int32"
+ if "cblas" in target.libs:
Review comment:
IMHO, if users specified kernel library in their target, but we cannot
use the library in this matmul, it might be better to throw a warning.
##########
File path: python/tvm/topi/nn/dense.py
##########
@@ -51,37 +65,120 @@ def dense(data, weight, bias=None, out_dtype=None,
auto_scheduler_rewritten_layo
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
- batch, in_dim = data.shape
+ if data_transposed:
+ in_dim, batch = data.shape
+ else:
+ batch, in_dim = data.shape
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
- auto_scheduler_rewritten_layout, ["j", "k"]
+ auto_scheduler_rewritten_layout, ["j", "k"] if weight_transposed
else ["k", "j"]
)
auto_scheduler.remove_index_check(weight)
- else:
+ elif weight_transposed:
out_dim, red_dim = weight.shape
+ else:
+ red_dim, out_dim = weight.shape
assert in_dim == red_dim
k = te.reduce_axis((0, in_dim), name="k")
- matmul = te.compute(
+ if data_transposed:
+ if weight_transposed:
+ compute_lambda = lambda i, j: te.sum(
+ data[k, i].astype(out_dtype) * weight[j, k].astype(out_dtype),
axis=k
+ )
+ compute_name = "T_matmul_TT"
+ else:
+ compute_lambda = lambda i, j: te.sum(
+ data[k, i].astype(out_dtype) * weight[k, j].astype(out_dtype),
axis=k
+ )
+ compute_name = "T_matmul_TN"
+ compute_tag = "matmul"
+ else:
+ if weight_transposed:
+ compute_lambda = lambda i, j: te.sum(
+ data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype),
axis=k
+ )
+ compute_name = "T_dense"
Review comment:
@jcf94 please discuss as I have the same issue above.
##########
File path: python/tvm/topi/x86/dense.py
##########
@@ -281,72 +281,121 @@ def _callback(op):
return s
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
- """Compute dense using a BLAS library"""
+def matmul_blas_common(cfg, data, weight, bias, out_dtype, data_transposed,
weight_transposed, lib):
+ """Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype ==
"int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
- f"Dense with {lib.__name__} for {data.dtype} is not supported "
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported "
"(matmulu8s8s32 not imlemented)"
)
- C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
+ C = lib.matmul_u8s8s32(data, weight, data_transposed,
weight_transposed, dtype=out_dtype)
elif data.dtype == "float32" or data.dtype == "float64":
- C = lib.matmul(data, weight, False, True)
+ C = lib.matmul(data, weight, data_transposed, weight_transposed)
else:
- raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype}
is not supported")
+ raise NotImplementedError(
+ f"Matmul/Dense with {lib.__name__} for {data.dtype} is not
supported"
+ )
if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] +
bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
+def schedule_matmul_blas_common(outs):
+ """Default matmul schedule for BLAS library"""
+ s = te.create_schedule([x.op for x in outs])
+ te.schedule.AutoInlineInjective(s)
+
+ for out in outs:
+ if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+ schedule_injective_from_existing(s, out)
+ return s
+
+
@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using a cblas"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
cblas)
@autotvm.register_topi_schedule("dense_cblas.x86")
def schedule_dense_cblas(_, outs):
"""Create schedule for dense_cblas"""
- return generic.schedule_extern(outs)
+ return schedule_matmul_blas_common(outs)
@autotvm.register_topi_compute("dense_mkl.x86")
def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using mkl"""
- return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl)
+ return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True,
mkl)
@autotvm.register_topi_schedule("dense_mkl.x86")
def schedule_dense_mkl(_, outs):
"""Create schedule for dense_mkl"""
Review comment:
ditto
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]