This is an automated email from the ASF dual-hosted git repository.
expye pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5f5d90a3f5 [TOPI] Enhance `topi.nn.matmul` (#16052)
5f5d90a3f5 is described below
commit 5f5d90a3f5e907615c5bc1f0ba72fd8441e73612
Author: Yixin Dong <[email protected]>
AuthorDate: Sun Nov 12 01:23:01 2023 -0800
[TOPI] Enhance `topi.nn.matmul` (#16052)
* 1102
* 1109
---
python/tvm/topi/nn/dense.py | 125 +++++++++++++++------------
tests/python/topi/python/test_topi_matmul.py | 61 +++++++++++--
2 files changed, 124 insertions(+), 62 deletions(-)
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index ce3aebadb6..d81060fe8b 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -19,7 +19,7 @@
import tvm
from tvm import auto_scheduler, te
-from .. import tag
+from .. import tag, add
def matmul(
@@ -65,86 +65,103 @@ def matmul(
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
- # TODO(jcf94): Add multi-dim support for tensor_a
- assert len(tensor_a.shape) == 2, "only support 2-dim matmul"
+ # TODO(yixin): support cases for 1-dim input
+ # TODO(yixin): adding support and further check for >2-dim input in autotvm template
+ assert (
+ len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
+ ), "1-dim matmul is not supported yet."
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = tensor_a.dtype
if transpose_a:
- in_dim, batch = tensor_a.shape
+ reduce_dim_a, in_dim = tensor_a.shape[-2:]
else:
- batch, in_dim = tensor_a.shape
+ in_dim, reduce_dim_a = tensor_a.shape[-2:]
+ batch_dims_a = tensor_a.shape[:-2]
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
- out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
+ assert len(tensor_b).shape == 2, "only support 2-dim matmul when using auto-scheduler"
+ out_dim, reduce_dim_b = auto_scheduler.get_shape_from_rewritten_layout(
auto_scheduler_rewritten_layout, ["j", "k"]
)
auto_scheduler.remove_index_check(tensor_b)
elif meta_schedule_original_shape:
auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
if transpose_b:
- out_dim, red_dim = tensor_b.shape
+ out_dim, reduce_dim_b = tensor_b.shape[-2:]
else:
- red_dim, out_dim = tensor_b.shape
+ reduce_dim_b, out_dim = tensor_b.shape[-2:]
elif transpose_b:
- out_dim, red_dim = tensor_b.shape
+ out_dim, reduce_dim_b = tensor_b.shape[-2:]
else:
- red_dim, out_dim = tensor_b.shape
-
- # cmp should be done by values
- condition = True
- if isinstance(in_dim, tvm.tir.SizeVar): # "any_dim"
- condition = False
- elif isinstance(red_dim, tvm.tir.SizeVar): # "any_dim"
- condition = False
- if condition:
- assert int(in_dim) == int(
- red_dim
- ), "Inner dimensions of dense do not match. {in_dim} vs {red_dim}."
-
- k = te.reduce_axis((0, in_dim), name="k")
- if (transpose_a, transpose_b) == (True, True):
- compute_lambda = lambda i, j: te.sum(
- tensor_a[k, i].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
+ reduce_dim_b, out_dim = tensor_b.shape[-2:]
+ batch_dims_b = tensor_b.shape[:-2]
+
+ if not isinstance(reduce_dim_a, tvm.tir.Var) and not isinstance(reduce_dim_b, tvm.tir.Var):
+ assert int(reduce_dim_a) == int(
+ reduce_dim_b
+ ), f"Reduction dimensions of dense do not match. {reduce_dim_a} vs {reduce_dim_b}."
+
+ result_ndim = max(len(batch_dims_a), len(batch_dims_b))
+ batch_dims_a = [1] * (result_ndim - len(batch_dims_a)) + batch_dims_a
+ batch_dims_b = [1] * (result_ndim - len(batch_dims_b)) + batch_dims_b
+
+ for idx, (l, r) in enumerate(zip(batch_dims_a, batch_dims_b)):
+ if (
+ not isinstance(l, tvm.tir.Var)
+ and not isinstance(r, tvm.tir.Var)
+ and int(l) != 1
+ and int(r) != 1
+ ):
+ assert int(l) == int(r), (
+ "Batch dimensions of dense do not match: "
+ f"{tensor_a.shape[:-2]} vs {tensor_b.shape[:-2]}."
+ )
+ if not isinstance(l, tvm.tir.Var) and int(l) == 1:
+ batch_dims_a[idx] = batch_dims_b[idx]
+
+ k = te.reduce_axis((0, reduce_dim_a), name="k")
+
+ def compute(*indices):
+ batch_indices_a = indices[-len(tensor_a.shape) : -2]
+ batch_indices_a = [
+ i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0
+ for i, dim in zip(batch_indices_a, tensor_a.shape[:-2])
+ ]
+ batch_indices_b = indices[-len(tensor_b.shape) : -2]
+ batch_indices_b = [
+ i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0
+ for i, dim in zip(batch_indices_b, tensor_b.shape[:-2])
+ ]
+ i, j = indices[-2:]
+ a_indices = (*batch_indices_a, k, i) if transpose_a else (*batch_indices_a, i, k)
+ b_indices = (*batch_indices_b, j, k) if transpose_b else (*batch_indices_b, k, j)
+ return te.sum(
+ tensor_a[a_indices].astype(out_dtype) * tensor_b[b_indices].astype(out_dtype), axis=k
)
- compute_name = "T_matmul_TT"
- compute_tag = "matmul"
- elif (transpose_a, transpose_b) == (True, False):
- compute_lambda = lambda i, j: te.sum(
- tensor_a[k, i].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
- )
- compute_name = "T_matmul_TN"
- compute_tag = "matmul"
- elif (transpose_a, transpose_b) == (False, True):
- compute_lambda = lambda i, j: te.sum(
- tensor_a[i, k].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
- )
- compute_name = "T_matmul_NT"
- # TODO(jcf94): Remove `dense` when `matmul` is finally ready
- compute_tag = "dense"
- else: # (transpose_a, transpose_b) == (False, False):
- compute_lambda = lambda i, j: te.sum(
- tensor_a[i, k].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
- )
- compute_name = "T_matmul_NN"
- compute_tag = "matmul"
+
+ compute_name = {
+ (True, True): "T_matmul_TT",
+ (True, False): "T_matmul_TN",
+ (False, True): "T_matmul_NT",
+ (False, False): "T_matmul_NN",
+ }[(transpose_a, transpose_b)]
+
+ # TODO(jcf94): Remove `dense` when `matmul` is finally ready
+ compute_tag = "dense" if (transpose_a, transpose_b) == (False, True) else "matmul"
mat = te.compute(
- (batch, out_dim),
- compute_lambda,
+ (*batch_dims_a, in_dim, out_dim),
+ compute,
name=compute_name,
tag=compute_tag,
attrs={"layout_free_placeholders": [tensor_b]},
)
if bias is not None:
- mat = te.compute(
- (batch, out_dim),
- lambda i, j: mat[i, j] + bias[j].astype(out_dtype),
- tag=tag.BROADCAST,
- )
+ mat = add(mat, bias.astype(out_dtype))
if auto_scheduler_rewritten_layout:
mat = auto_scheduler.rewrite_compute_body(mat, auto_scheduler_rewritten_layout)
diff --git a/tests/python/topi/python/test_topi_matmul.py b/tests/python/topi/python/test_topi_matmul.py
index de2d4d3c4c..4b05dd3813 100644
--- a/tests/python/topi/python/test_topi_matmul.py
+++ b/tests/python/topi/python/test_topi_matmul.py
@@ -41,15 +41,40 @@ def with_tvm(lam, *args):
return out_nd.numpy()
-def verify_nn_matmul(sa, sb, transp_a, transp_b):
+def verify_nn_matmul(sa, sb, transp_a, transp_b, bias=False):
a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
- c1 = np.matmul(np.transpose(a) if transp_a else a, np.transpose(b) if transp_b else b)
- c2 = with_tvm(
- lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b),
- a,
- b,
- )
+ if bias:
+ bias_shape = sb[-2] if transp_b else sb[-1]
+ bias_np = np.random.uniform(low=-1.0, high=1.0, size=(bias_shape,)).astype(np.float32)
+
+ a_np = a
+ if transp_a:
+ axes = list(range(len(sa)))
+ axes[-2], axes[-1] = axes[-1], axes[-2]
+ a_np = np.transpose(a_np, axes)
+ b_np = b
+ if transp_b:
+ axes = list(range(len(sb)))
+ axes[-2], axes[-1] = axes[-1], axes[-2]
+ b_np = np.transpose(b_np, axes)
+
+ if bias:
+ c1 = np.matmul(a_np, b_np) + bias_np
+ c2 = with_tvm(
+ lambda A, B, bias: topi.nn.matmul(
+ A, B, transpose_a=transp_a, transpose_b=transp_b, bias=bias
+ ),
+ a,
+ b,
+ bias_np,
+ )
+ else:
+ c1 = np.matmul(a_np, b_np)
+ c2 = with_tvm(
+ lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b), a, b
+ )
+
tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
@@ -60,10 +85,30 @@ def test_nn_matmul():
verify_nn_matmul((2, 2), (2, 2), True, True)
verify_nn_matmul((2, 3), (3, 5), False, False)
verify_nn_matmul((5, 3), (3, 2), False, False)
- verify_nn_matmul((3, 5), (3, 2), True, False)
verify_nn_matmul((3, 5), (2, 3), True, True)
verify_nn_matmul((3, 5), (3, 2), True, False)
verify_nn_matmul((5, 3), (2, 3), False, True)
+ # matmul with bias
+ verify_nn_matmul((5, 3), (3, 2), False, False, True)
+ verify_nn_matmul((3, 5), (2, 3), True, True, True)
+ verify_nn_matmul((3, 5), (3, 2), True, False, True)
+ verify_nn_matmul((5, 3), (2, 3), False, True, True)
+ # batched matmul
+ verify_nn_matmul((4, 5, 3), (4, 3, 2), False, False)
+ verify_nn_matmul((4, 3, 5), (4, 2, 3), True, True)
+ verify_nn_matmul((4, 3, 5), (4, 3, 2), True, False)
+ verify_nn_matmul((4, 5, 3), (4, 2, 3), False, True)
+ # batched matmul with broadcast
+ verify_nn_matmul((4, 5, 3), (1, 2, 3), False, True)
+ verify_nn_matmul((1, 5, 3), (4, 2, 3), False, True)
+ verify_nn_matmul((5, 3), (4, 2, 3), False, True)
+ verify_nn_matmul((4, 5, 3), (2, 3), False, True)
+ verify_nn_matmul((2, 4, 5, 3), (1, 2, 3), False, True)
+ # batched matmul with bias
+ verify_nn_matmul((4, 5, 3), (4, 3, 2), False, False, True)
+ verify_nn_matmul((4, 3, 5), (4, 2, 3), True, True, True)
+ verify_nn_matmul((4, 3, 5), (4, 3, 2), True, False, True)
+ verify_nn_matmul((4, 5, 3), (4, 2, 3), False, True, True)
def verify_matmul(sa, sb, transp_a, transp_b):