comaniac commented on a change in pull request #8527:
URL: https://github.com/apache/tvm/pull/8527#discussion_r675315468



##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -21,73 +21,117 @@
 from ..utils import get_const_tuple
 
 
-def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="", 
out_dtype=None):
-    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+def batch_matmul(
+    tensor_a,
+    tensor_b,
+    oshape=None,
+    out_dtype=None,
+    transpose_a=False,
+    transpose_b=True,
+    auto_scheduler_rewritten_layout="",
+):
+    """Computes batch matrix multiplication of `A` and `B` when `A` and `B` are
     data in batch. Supports broadcasting for batch dimension.
 
+    Both A and B can be transposed. For legacy reasons, we use the NT format
+    (tensor_a non-transposed and tensor_b transposed) by default.
+
     Parameters
     ----------
-    x : tvm.te.Tensor
-        3-D with shape [batch, M, K]
+    tensor_a : tvm.te.Tensor
+        3-D with shape [batch, M, K] or [batch, K, M]
 
-    y : tvm.te.Tensor
-        3-D with shape [batch, N, K]
+    tensor_b : tvm.te.Tensor
+        3-D with shape [batch, K, N] or [batch, N, K]
 
     oshape : List[Optional]
         Explicit intended output shape of the computation. Can be useful in 
cases
         with dynamic input shapes.
 
-    auto_scheduler_rewritten_layout: str = ""
+    auto_scheduler_rewritten_layout: Optional[str] = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul
+
+    transpose_a : Optional[bool] = False
+        Whether the data tensor is in transposed format.
+
+    transpose_b : Optional[bool] = True
+        Whether the weight tensor is in transposed format.
+
     Returns
     -------
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    x_shape = get_const_tuple(x.shape)
+    assert len(tensor_a.shape) == 3, "only support 3-dim batch_matmul"
+    if transpose_a:
+        XB, XK, XI = get_const_tuple(tensor_a.shape)
+    else:
+        XB, XI, XK = get_const_tuple(tensor_a.shape)
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
-        y_shape = auto_scheduler.get_shape_from_rewritten_layout(
-            auto_scheduler_rewritten_layout, ["b", "j", "k"]
+        YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout(
+            auto_scheduler_rewritten_layout, ["b", "k", "j"]
         )
-        auto_scheduler.remove_index_check(y)
+        auto_scheduler.remove_index_check(tensor_b)
     else:
-        y_shape = get_const_tuple(y.shape)
-    assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim 
batch_matmul"
+        assert len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
+        if transpose_b:
+            YB, YJ, YK = get_const_tuple(tensor_b.shape)
+        else:
+            YB, YK, YJ = get_const_tuple(tensor_b.shape)
 
-    XB = x_shape[0]
-    YB = y_shape[0]
-    _, M, K = x.shape
-    k = te.reduce_axis((0, K), name="k")
+    assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y is 
inconsistent"
+    k = te.reduce_axis((0, XK), name="k")
     if oshape is None:
         assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
-        assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"
-        batch = te.max(XB, YB)
-        N = y.shape[1]
-        oshape = (batch, M, N)
-
-    if out_dtype is None or out_dtype == x.dtype:
-        output = te.compute(
-            oshape,
-            lambda b, i, j: te.sum(
-                x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], 
axis=k
-            ),
-            tag="batch_matmul",
-            attrs={"layout_free_placeholders": [y]},
+        batch = (
+            tvm.tir.Any()
+            if isinstance(XB, tvm.tir.expr.Var) or isinstance(YB, 
tvm.tir.expr.Var)
+            else te.max(XB, YB)
         )
-    else:
-        output = te.compute(
-            oshape,
-            lambda b, i, j: te.sum(
-                x[b if XB != 1 else 0, i, k].astype(out_dtype)
-                * y[b if YB != 1 else 0, j, k].astype(out_dtype),
-                axis=k,
-            ),
-            tag="batch_matmul",
-            attrs={"layout_free_placeholders": [y]},
+        oshape = (batch, XI, YJ)
+    if out_dtype is None:
+        out_dtype = tensor_a.dtype

Review comment:
       That doesn't sound right... what if A is in fp16 and B is in fp32?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to