This is an automated email from the ASF dual-hosted git repository.

expye pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5f5d90a3f5 [TOPI] Enhance `topi.nn.matmul` (#16052)
5f5d90a3f5 is described below

commit 5f5d90a3f5e907615c5bc1f0ba72fd8441e73612
Author: Yixin Dong <[email protected]>
AuthorDate: Sun Nov 12 01:23:01 2023 -0800

    [TOPI] Enhance `topi.nn.matmul` (#16052)
    
    * 1102
    
    * 1109
---
 python/tvm/topi/nn/dense.py                  | 125 +++++++++++++++------------
 tests/python/topi/python/test_topi_matmul.py |  61 +++++++++++--
 2 files changed, 124 insertions(+), 62 deletions(-)

diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index ce3aebadb6..d81060fe8b 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -19,7 +19,7 @@
 import tvm
 from tvm import auto_scheduler, te
 
-from .. import tag
+from .. import tag, add
 
 
 def matmul(
@@ -65,86 +65,103 @@ def matmul(
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    # TODO(jcf94): Add multi-dim support for tensor_a
-    assert len(tensor_a.shape) == 2, "only support 2-dim matmul"
+    # TODO(yixin): support cases for 1-dim input
+    # TODO(yixin): adding support and further check for >2-dim input in autotvm template
+    assert (
+        len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
+    ), "1-dim matmul is not supported yet."
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = tensor_a.dtype
     if transpose_a:
-        in_dim, batch = tensor_a.shape
+        reduce_dim_a, in_dim = tensor_a.shape[-2:]
     else:
-        batch, in_dim = tensor_a.shape
+        in_dim, reduce_dim_a = tensor_a.shape[-2:]
+    batch_dims_a = tensor_a.shape[:-2]
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
-        out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
+        assert len(tensor_b.shape) == 2, "only support 2-dim matmul when using auto-scheduler"
+        out_dim, reduce_dim_b = auto_scheduler.get_shape_from_rewritten_layout(
             auto_scheduler_rewritten_layout, ["j", "k"]
         )
         auto_scheduler.remove_index_check(tensor_b)
     elif meta_schedule_original_shape:
         auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
         if transpose_b:
-            out_dim, red_dim = tensor_b.shape
+            out_dim, reduce_dim_b = tensor_b.shape[-2:]
         else:
-            red_dim, out_dim = tensor_b.shape
+            reduce_dim_b, out_dim = tensor_b.shape[-2:]
     elif transpose_b:
-        out_dim, red_dim = tensor_b.shape
+        out_dim, reduce_dim_b = tensor_b.shape[-2:]
     else:
-        red_dim, out_dim = tensor_b.shape
-
-    # cmp should be done by values
-    condition = True
-    if isinstance(in_dim, tvm.tir.SizeVar):  # "any_dim"
-        condition = False
-    elif isinstance(red_dim, tvm.tir.SizeVar):  # "any_dim"
-        condition = False
-    if condition:
-        assert int(in_dim) == int(
-            red_dim
-        ), "Inner dimensions of dense do not match. {in_dim} vs {red_dim}."
-
-    k = te.reduce_axis((0, in_dim), name="k")
-    if (transpose_a, transpose_b) == (True, True):
-        compute_lambda = lambda i, j: te.sum(
-            tensor_a[k, i].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
+        reduce_dim_b, out_dim = tensor_b.shape[-2:]
+    batch_dims_b = tensor_b.shape[:-2]
+
+    if not isinstance(reduce_dim_a, tvm.tir.Var) and not isinstance(reduce_dim_b, tvm.tir.Var):
+        assert int(reduce_dim_a) == int(
+            reduce_dim_b
+        ), f"Reduction dimensions of dense do not match. {reduce_dim_a} vs {reduce_dim_b}."
+
+    result_ndim = max(len(batch_dims_a), len(batch_dims_b))
+    batch_dims_a = [1] * (result_ndim - len(batch_dims_a)) + batch_dims_a
+    batch_dims_b = [1] * (result_ndim - len(batch_dims_b)) + batch_dims_b
+
+    for idx, (l, r) in enumerate(zip(batch_dims_a, batch_dims_b)):
+        if (
+            not isinstance(l, tvm.tir.Var)
+            and not isinstance(r, tvm.tir.Var)
+            and int(l) != 1
+            and int(r) != 1
+        ):
+            assert int(l) == int(r), (
+                "Batch dimensions of dense do not match: "
+                f"{tensor_a.shape[:-2]} vs {tensor_b.shape[:-2]}."
+            )
+        if not isinstance(l, tvm.tir.Var) and int(l) == 1:
+            batch_dims_a[idx] = batch_dims_b[idx]
+
+    k = te.reduce_axis((0, reduce_dim_a), name="k")
+
+    def compute(*indices):
+        batch_indices_a = indices[-len(tensor_a.shape) : -2]
+        batch_indices_a = [
+            i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0
+            for i, dim in zip(batch_indices_a, tensor_a.shape[:-2])
+        ]
+        batch_indices_b = indices[-len(tensor_b.shape) : -2]
+        batch_indices_b = [
+            i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0
+            for i, dim in zip(batch_indices_b, tensor_b.shape[:-2])
+        ]
+        i, j = indices[-2:]
+        a_indices = (*batch_indices_a, k, i) if transpose_a else (*batch_indices_a, i, k)
+        b_indices = (*batch_indices_b, j, k) if transpose_b else (*batch_indices_b, k, j)
+        return te.sum(
+            tensor_a[a_indices].astype(out_dtype) * tensor_b[b_indices].astype(out_dtype), axis=k
         )
-        compute_name = "T_matmul_TT"
-        compute_tag = "matmul"
-    elif (transpose_a, transpose_b) == (True, False):
-        compute_lambda = lambda i, j: te.sum(
-            tensor_a[k, i].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
-        )
-        compute_name = "T_matmul_TN"
-        compute_tag = "matmul"
-    elif (transpose_a, transpose_b) == (False, True):
-        compute_lambda = lambda i, j: te.sum(
-            tensor_a[i, k].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
-        )
-        compute_name = "T_matmul_NT"
-        # TODO(jcf94): Remove `dense` when `matmul` is finally ready
-        compute_tag = "dense"
-    else:  # (transpose_a, transpose_b) == (False, False):
-        compute_lambda = lambda i, j: te.sum(
-            tensor_a[i, k].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
-        )
-        compute_name = "T_matmul_NN"
-        compute_tag = "matmul"
+
+    compute_name = {
+        (True, True): "T_matmul_TT",
+        (True, False): "T_matmul_TN",
+        (False, True): "T_matmul_NT",
+        (False, False): "T_matmul_NN",
+    }[(transpose_a, transpose_b)]
+
+    # TODO(jcf94): Remove `dense` when `matmul` is finally ready
+    compute_tag = "dense" if (transpose_a, transpose_b) == (False, True) else "matmul"
 
     mat = te.compute(
-        (batch, out_dim),
-        compute_lambda,
+        (*batch_dims_a, in_dim, out_dim),
+        compute,
         name=compute_name,
         tag=compute_tag,
         attrs={"layout_free_placeholders": [tensor_b]},
     )
 
     if bias is not None:
-        mat = te.compute(
-            (batch, out_dim),
-            lambda i, j: mat[i, j] + bias[j].astype(out_dtype),
-            tag=tag.BROADCAST,
-        )
+        mat = add(mat, bias.astype(out_dtype))
 
     if auto_scheduler_rewritten_layout:
         mat = auto_scheduler.rewrite_compute_body(mat, auto_scheduler_rewritten_layout)
diff --git a/tests/python/topi/python/test_topi_matmul.py b/tests/python/topi/python/test_topi_matmul.py
index de2d4d3c4c..4b05dd3813 100644
--- a/tests/python/topi/python/test_topi_matmul.py
+++ b/tests/python/topi/python/test_topi_matmul.py
@@ -41,15 +41,40 @@ def with_tvm(lam, *args):
     return out_nd.numpy()
 
 
-def verify_nn_matmul(sa, sb, transp_a, transp_b):
+def verify_nn_matmul(sa, sb, transp_a, transp_b, bias=False):
     a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
     b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
-    c1 = np.matmul(np.transpose(a) if transp_a else a, np.transpose(b) if transp_b else b)
-    c2 = with_tvm(
-        lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b),
-        a,
-        b,
-    )
+    if bias:
+        bias_shape = sb[-2] if transp_b else sb[-1]
+        bias_np = np.random.uniform(low=-1.0, high=1.0, size=(bias_shape,)).astype(np.float32)
+
+    a_np = a
+    if transp_a:
+        axes = list(range(len(sa)))
+        axes[-2], axes[-1] = axes[-1], axes[-2]
+        a_np = np.transpose(a_np, axes)
+    b_np = b
+    if transp_b:
+        axes = list(range(len(sb)))
+        axes[-2], axes[-1] = axes[-1], axes[-2]
+        b_np = np.transpose(b_np, axes)
+
+    if bias:
+        c1 = np.matmul(a_np, b_np) + bias_np
+        c2 = with_tvm(
+            lambda A, B, bias: topi.nn.matmul(
+                A, B, transpose_a=transp_a, transpose_b=transp_b, bias=bias
+            ),
+            a,
+            b,
+            bias_np,
+        )
+    else:
+        c1 = np.matmul(a_np, b_np)
+        c2 = with_tvm(
+            lambda A, B: topi.nn.matmul(A, B, transpose_a=transp_a, transpose_b=transp_b), a, b
+        )
+
     tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
 
 
@@ -60,10 +85,30 @@ def test_nn_matmul():
     verify_nn_matmul((2, 2), (2, 2), True, True)
     verify_nn_matmul((2, 3), (3, 5), False, False)
     verify_nn_matmul((5, 3), (3, 2), False, False)
-    verify_nn_matmul((3, 5), (3, 2), True, False)
     verify_nn_matmul((3, 5), (2, 3), True, True)
     verify_nn_matmul((3, 5), (3, 2), True, False)
     verify_nn_matmul((5, 3), (2, 3), False, True)
+    # matmul with bias
+    verify_nn_matmul((5, 3), (3, 2), False, False, True)
+    verify_nn_matmul((3, 5), (2, 3), True, True, True)
+    verify_nn_matmul((3, 5), (3, 2), True, False, True)
+    verify_nn_matmul((5, 3), (2, 3), False, True, True)
+    # batched matmul
+    verify_nn_matmul((4, 5, 3), (4, 3, 2), False, False)
+    verify_nn_matmul((4, 3, 5), (4, 2, 3), True, True)
+    verify_nn_matmul((4, 3, 5), (4, 3, 2), True, False)
+    verify_nn_matmul((4, 5, 3), (4, 2, 3), False, True)
+    # batched matmul with broadcast
+    verify_nn_matmul((4, 5, 3), (1, 2, 3), False, True)
+    verify_nn_matmul((1, 5, 3), (4, 2, 3), False, True)
+    verify_nn_matmul((5, 3), (4, 2, 3), False, True)
+    verify_nn_matmul((4, 5, 3), (2, 3), False, True)
+    verify_nn_matmul((2, 4, 5, 3), (1, 2, 3), False, True)
+    # batched matmul with bias
+    verify_nn_matmul((4, 5, 3), (4, 3, 2), False, False, True)
+    verify_nn_matmul((4, 3, 5), (4, 2, 3), True, True, True)
+    verify_nn_matmul((4, 3, 5), (4, 3, 2), True, False, True)
+    verify_nn_matmul((4, 5, 3), (4, 2, 3), False, True, True)
 
 
 def verify_matmul(sa, sb, transp_a, transp_b):

Reply via email to