comaniac commented on a change in pull request #8527:
URL: https://github.com/apache/tvm/pull/8527#discussion_r675000279
##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -21,73 +21,117 @@
from ..utils import get_const_tuple
-def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="",
out_dtype=None):
- """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+def batch_matmul(
+ tensor_a,
+ tensor_b,
+ oshape=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=True,
+ auto_scheduler_rewritten_layout="",
+):
+ """Computes batch matrix multiplication of `A` and `B` when `A` and `B` are
Review comment:
ditto
##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -1276,24 +1276,34 @@ def dense_pack_shape_func(attrs, inputs, _):
@script
-def _batch_matmul_shape_func(data_shape, weight_shape):
- out = output_tensor((data_shape.shape[0],), "int64")
- for i in const_range(out.shape[0] - 1):
- if i == 0:
- out[i] = max(data_shape[i], weight_shape[i])
- else:
- out[i] = data_shape[i]
- out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]
+def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a,
transpose_b):
+ out = output_tensor((tensor_a_shape.shape[0],), "int64")
+ out[0] = max(tensor_a_shape[0], tensor_b_shape[0])
+ if transpose_a:
+ out[1] = tensor_a_shape[2]
+ else:
+ out[1] = tensor_a_shape[1]
+ if transpose_b:
+ out[2] = tensor_b_shape[1]
+ else:
+ out[2] = tensor_b_shape[2]
Review comment:
```suggestion
out[1] = tensor_a_shape[2 if transpose_a else 1]
    out[2] = tensor_b_shape[1 if transpose_b else 2]
```
##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -21,73 +21,117 @@
from ..utils import get_const_tuple
-def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="",
out_dtype=None):
- """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+def batch_matmul(
+ tensor_a,
+ tensor_b,
+ oshape=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=True,
+ auto_scheduler_rewritten_layout="",
+):
+ """Computes batch matrix multiplication of `A` and `B` when `A` and `B` are
data in batch. Supports broadcasting for batch dimension.
+ The A & B can be transposed. For legacy reason, we use NT format(tensor_a
non-transposed
+ and tensor_b transposed) by default.
+
Parameters
----------
- x : tvm.te.Tensor
- 3-D with shape [batch, M, K]
+ tensor_a : tvm.te.Tensor
+ 3-D with shape [batch, M, K] or [batch, K, M]
- y : tvm.te.Tensor
- 3-D with shape [batch, N, K]
+ tensor_b : tvm.te.Tensor
+ 3-D with shape [batch, K, N] or [batch, N, K]
oshape : List[Optional]
Explicit intended output shape of the computation. Can be useful in
cases
with dynamic input shapes.
- auto_scheduler_rewritten_layout: str = ""
+ auto_scheduler_rewritten_layout: Optional[str] = ""
The layout after auto-scheduler's layout rewrite pass.
+ out_dtype : Optional[str]
+ Specifies the output data type for mixed precision batch matmul
+
+ transpose_a : Optional[bool] = False
+ Whether the data tensor is in transposed format.
+
+ transpose_b : Optional[bool] = True
+ Whether the weight tensor is in transposed format.
+
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
- x_shape = get_const_tuple(x.shape)
+ assert len(tensor_a.shape) == 3, "only support 3-dim batch_matmul"
+ if transpose_a:
+ XB, XK, XI = get_const_tuple(tensor_a.shape)
+ else:
+ XB, XI, XK = get_const_tuple(tensor_a.shape)
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
- y_shape = auto_scheduler.get_shape_from_rewritten_layout(
- auto_scheduler_rewritten_layout, ["b", "j", "k"]
+ YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout(
+ auto_scheduler_rewritten_layout, ["b", "k", "j"]
)
- auto_scheduler.remove_index_check(y)
+ auto_scheduler.remove_index_check(tensor_b)
else:
- y_shape = get_const_tuple(y.shape)
- assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim
batch_matmul"
+ assert len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
Review comment:
```suggestion
assert len(tensor_b.shape) == 3, "only support 3-dim tensor_b"
```
##########
File path: src/relay/transforms/combine_parallel_batch_matmul.cc
##########
@@ -86,7 +93,7 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner
{
const auto* origin_attrs = branches[0][0]->attrs.as<BatchMatmulAttrs>();
ICHECK(origin_attrs);
- return Downcast<Call>(MakeBatchMatmul(data, new_weight,
origin_attrs->out_dtype));
+ return Downcast<Call>(MakeBatchMatmul(data, new_weight,
origin_attrs->out_dtype, false, true));
Review comment:
Better to still use `attrs_a->transpose_a`, etc.
##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -21,73 +21,117 @@
from ..utils import get_const_tuple
-def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="",
out_dtype=None):
- """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+def batch_matmul(
+ tensor_a,
+ tensor_b,
+ oshape=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=True,
+ auto_scheduler_rewritten_layout="",
+):
+ """Computes batch matrix multiplication of `A` and `B` when `A` and `B` are
data in batch. Supports broadcasting for batch dimension.
+ The A & B can be transposed. For legacy reason, we use NT format(tensor_a
non-transposed
+ and tensor_b transposed) by default.
+
Parameters
----------
- x : tvm.te.Tensor
- 3-D with shape [batch, M, K]
+ tensor_a : tvm.te.Tensor
+ 3-D with shape [batch, M, K] or [batch, K, M]
- y : tvm.te.Tensor
- 3-D with shape [batch, N, K]
+ tensor_b : tvm.te.Tensor
+ 3-D with shape [batch, K, N] or [batch, N, K]
oshape : List[Optional]
Explicit intended output shape of the computation. Can be useful in
cases
with dynamic input shapes.
- auto_scheduler_rewritten_layout: str = ""
+ auto_scheduler_rewritten_layout: Optional[str] = ""
The layout after auto-scheduler's layout rewrite pass.
+ out_dtype : Optional[str]
+ Specifies the output data type for mixed precision batch matmul
+
+ transpose_a : Optional[bool] = False
+ Whether the data tensor is in transposed format.
+
+ transpose_b : Optional[bool] = True
+ Whether the weight tensor is in transposed format.
+
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
- x_shape = get_const_tuple(x.shape)
+ assert len(tensor_a.shape) == 3, "only support 3-dim batch_matmul"
Review comment:
```suggestion
assert len(tensor_a.shape) == 3, "only support 3-dim tensor_a"
```
##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -21,73 +21,117 @@
from ..utils import get_const_tuple
-def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="",
out_dtype=None):
- """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+def batch_matmul(
+ tensor_a,
+ tensor_b,
+ oshape=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=True,
+ auto_scheduler_rewritten_layout="",
+):
+ """Computes batch matrix multiplication of `A` and `B` when `A` and `B` are
data in batch. Supports broadcasting for batch dimension.
+ The A & B can be transposed. For legacy reason, we use NT format(tensor_a
non-transposed
+ and tensor_b transposed) by default.
+
Parameters
----------
- x : tvm.te.Tensor
- 3-D with shape [batch, M, K]
+ tensor_a : tvm.te.Tensor
+ 3-D with shape [batch, M, K] or [batch, K, M]
- y : tvm.te.Tensor
- 3-D with shape [batch, N, K]
+ tensor_b : tvm.te.Tensor
+ 3-D with shape [batch, K, N] or [batch, N, K]
oshape : List[Optional]
Explicit intended output shape of the computation. Can be useful in
cases
with dynamic input shapes.
- auto_scheduler_rewritten_layout: str = ""
+ auto_scheduler_rewritten_layout: Optional[str] = ""
The layout after auto-scheduler's layout rewrite pass.
+ out_dtype : Optional[str]
+ Specifies the output data type for mixed precision batch matmul
+
+ transpose_a : Optional[bool] = False
+ Whether the data tensor is in transposed format.
+
+ transpose_b : Optional[bool] = True
+ Whether the weight tensor is in transposed format.
+
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
- x_shape = get_const_tuple(x.shape)
+ assert len(tensor_a.shape) == 3, "only support 3-dim batch_matmul"
+ if transpose_a:
+ XB, XK, XI = get_const_tuple(tensor_a.shape)
+ else:
+ XB, XI, XK = get_const_tuple(tensor_a.shape)
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
- y_shape = auto_scheduler.get_shape_from_rewritten_layout(
- auto_scheduler_rewritten_layout, ["b", "j", "k"]
+ YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout(
+ auto_scheduler_rewritten_layout, ["b", "k", "j"]
)
- auto_scheduler.remove_index_check(y)
+ auto_scheduler.remove_index_check(tensor_b)
else:
- y_shape = get_const_tuple(y.shape)
- assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim
batch_matmul"
+ assert len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
+ if transpose_b:
+ YB, YJ, YK = get_const_tuple(tensor_b.shape)
+ else:
+ YB, YK, YJ = get_const_tuple(tensor_b.shape)
- XB = x_shape[0]
- YB = y_shape[0]
- _, M, K = x.shape
- k = te.reduce_axis((0, K), name="k")
+ assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y is
inconsistent"
Review comment:
```suggestion
assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y
are inconsistent"
```
##########
File path: src/relay/qnn/op/batch_matmul.cc
##########
@@ -84,7 +84,9 @@ Expr MakeQuantizedBatchMatmul(Expr x, Expr y, Expr
x_zero_point, Expr y_zero_poi
Expr BatchMatmulFirstTerm(const Expr& quantized_x, const Expr& quantized_y,
const BatchMatmulAttrs* attrs) {
- return MakeBatchMatmul(quantized_x, quantized_y, attrs->out_dtype);
+ ICHECK(attrs->transpose_a == false && attrs->transpose_b == true)
+ << "Currently qnn.batch_matmul only support NT format.";
Review comment:
```suggestion
<< "Currently qnn.batch_matmul only supports (transpose_a=false,
transpose_b=true).";
```
##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -21,73 +21,117 @@
from ..utils import get_const_tuple
-def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="",
out_dtype=None):
- """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+def batch_matmul(
+ tensor_a,
+ tensor_b,
+ oshape=None,
+ out_dtype=None,
+ transpose_a=False,
+ transpose_b=True,
+ auto_scheduler_rewritten_layout="",
+):
+ """Computes batch matrix multiplication of `A` and `B` when `A` and `B` are
data in batch. Supports broadcasting for batch dimension.
+ The A & B can be transposed. For legacy reason, we use NT format(tensor_a
non-transposed
+ and tensor_b transposed) by default.
+
Parameters
----------
- x : tvm.te.Tensor
- 3-D with shape [batch, M, K]
+ tensor_a : tvm.te.Tensor
+ 3-D with shape [batch, M, K] or [batch, K, M]
- y : tvm.te.Tensor
- 3-D with shape [batch, N, K]
+ tensor_b : tvm.te.Tensor
+ 3-D with shape [batch, K, N] or [batch, N, K]
oshape : List[Optional]
Explicit intended output shape of the computation. Can be useful in
cases
with dynamic input shapes.
- auto_scheduler_rewritten_layout: str = ""
+ auto_scheduler_rewritten_layout: Optional[str] = ""
The layout after auto-scheduler's layout rewrite pass.
+ out_dtype : Optional[str]
+ Specifies the output data type for mixed precision batch matmul
+
+ transpose_a : Optional[bool] = False
+ Whether the data tensor is in transposed format.
+
+ transpose_b : Optional[bool] = True
+ Whether the weight tensor is in transposed format.
+
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
- x_shape = get_const_tuple(x.shape)
+ assert len(tensor_a.shape) == 3, "only support 3-dim batch_matmul"
+ if transpose_a:
+ XB, XK, XI = get_const_tuple(tensor_a.shape)
+ else:
+ XB, XI, XK = get_const_tuple(tensor_a.shape)
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
- y_shape = auto_scheduler.get_shape_from_rewritten_layout(
- auto_scheduler_rewritten_layout, ["b", "j", "k"]
+ YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout(
+ auto_scheduler_rewritten_layout, ["b", "k", "j"]
)
- auto_scheduler.remove_index_check(y)
+ auto_scheduler.remove_index_check(tensor_b)
else:
- y_shape = get_const_tuple(y.shape)
- assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim
batch_matmul"
+ assert len(tensor_b.shape) == 3, "only support 3-dim batch_matmul"
+ if transpose_b:
+ YB, YJ, YK = get_const_tuple(tensor_b.shape)
+ else:
+ YB, YK, YJ = get_const_tuple(tensor_b.shape)
- XB = x_shape[0]
- YB = y_shape[0]
- _, M, K = x.shape
- k = te.reduce_axis((0, K), name="k")
+ assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y is
inconsistent"
+ k = te.reduce_axis((0, XK), name="k")
if oshape is None:
assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
- assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"
- batch = te.max(XB, YB)
- N = y.shape[1]
- oshape = (batch, M, N)
-
- if out_dtype is None or out_dtype == x.dtype:
- output = te.compute(
- oshape,
- lambda b, i, j: te.sum(
- x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k],
axis=k
- ),
- tag="batch_matmul",
- attrs={"layout_free_placeholders": [y]},
+ batch = (
+ tvm.tir.Any()
+ if isinstance(XB, tvm.tir.expr.Var) or isinstance(YB,
tvm.tir.expr.Var)
+ else te.max(XB, YB)
)
- else:
- output = te.compute(
- oshape,
- lambda b, i, j: te.sum(
- x[b if XB != 1 else 0, i, k].astype(out_dtype)
- * y[b if YB != 1 else 0, j, k].astype(out_dtype),
- axis=k,
- ),
- tag="batch_matmul",
- attrs={"layout_free_placeholders": [y]},
+ oshape = (batch, XI, YJ)
+ if out_dtype is None:
+ out_dtype = tensor_a.dtype
Review comment:
   Do we need to check whether A and B have the same dtype?
##########
File path: src/relay/transforms/combine_parallel_batch_matmul.cc
##########
@@ -68,6 +68,13 @@ class ParallelBatchMatmulCombiner : public
ParallelOpCombiner {
// shape[2] is the contraction axis and automatically consistent
// if it were valid batch_matmul ops
+ // This pass only support the original NT format now
+ // TODO(jcf94): Add full support of layout format
+ if (attrs_a->transpose_a == true || attrs_a->transpose_b == false ||
+ attrs_b->transpose_a == true || attrs_b->transpose_b == false) {
+ return false;
 Review comment:
   This is required functionality, so it may be better to emit a warning here when the pass skips these layouts, rather than silently returning false?
##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -2137,32 +2137,41 @@ def group_norm(data, gamma, beta, num_groups, axis=1,
epsilon=1e-5, center=True,
return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon,
center, scale)
-def batch_matmul(x, y, out_dtype=""):
+def batch_matmul(tensor_a, tensor_b, out_dtype="", transpose_a=False,
transpose_b=True):
r"""
- Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data
+ Computes batch matrix multiplication of `A` and `B` when `A` and `B` are
data
in batch.
+
+ The A & B can be transposed. For legacy reason, we use NT format(tensor_a
non-transposed
+ and tensor_b transposed) by default.
Review comment:
```suggestion
Compute batch matrix multiplication of `tensor_a` and `tensor_b`.
Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we
use NT format
(transpose_a=False, transpose_b=True) by default.
```
##########
File path: python/tvm/topi/x86/batch_matmul.py
##########
@@ -45,40 +47,18 @@ def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None):
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
- assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim
batch_matmul"
- XB, M, XK = get_const_tuple(x.shape)
- YB, N, YK = get_const_tuple(y.shape)
- assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't
match"
- assert XK == YK, "shapes of x and y is inconsistent"
- B = te.max(XB, YB)
- K = XK
- if out_shape is not None:
- assert out_shape[0] == B, "got invalid output shape"
- assert out_shape[1] == M, "got invalid output shape"
- assert out_shape[2] == N, "got invalid output shape"
- if cfg.is_fallback:
- _default_batch_matmul_config(cfg, M, N, K)
-
- k = te.reduce_axis((0, K), name="k")
- if out_dtype is None or out_dtype == x.dtype:
- C = te.compute(
- (B, M, N),
- lambda b, i, j: te.sum(
- x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k],
axis=k
- ),
- tag="batch_matmul",
- )
- else:
- C = te.compute(
- (B, M, N),
- lambda b, i, j: te.sum(
- x[b if XB != 1 else 0, i, k].astype(out_dtype)
- * y[b if YB != 1 else 0, j, k].astype(out_dtype),
- axis=k,
- ),
- tag="batch_matmul",
- )
- return C
+ if cfg.is_fallback and not transpose_a and transpose_b:
Review comment:
1. Add a comment to explain why only NT has a default config.
2. It implies that other modes (NN, TN, TT) won't have a config and will
encounter an error. It might be better to have a proper message in the AutoTVM
NT schedule?
##########
File path: python/tvm/topi/x86/batch_matmul.py
##########
@@ -20,12 +20,14 @@
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas, mkl
-from .. import generic
+from .. import generic, nn
from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor
@autotvm.register_topi_compute("batch_matmul.x86")
-def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None):
+def batch_matmul(
+ cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None,
transpose_a=False, transpose_b=True
+):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
Review comment:
Update docstring and parameters
##########
File path: src/relay/op/nn/nn.cc
##########
@@ -959,10 +962,11 @@ are data in batch.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
- .add_argument("x", "3D Tensor", "First input.")
- .add_argument("y", "3D Tensor", "Second input.")
+ .add_argument("tensor_a", "3D Tensor", "First input.")
Review comment:
Update the description
##########
File path: src/relay/op/nn/nn.cc
##########
@@ -932,15 +932,18 @@ If the input has size k on axis 1, then both gamma and
beta have shape (k,).
.set_support_level(1)
.add_type_rel("GroupNorm", GroupNormRel);
-// relay.nn.batch_matmul
+// ------------------- relay.nn.batch_matmul
Review comment:
What's this for?
##########
File path: src/relay/op/nn/nn.cc
##########
@@ -959,10 +962,11 @@ are data in batch.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
- .add_argument("x", "3D Tensor", "First input.")
- .add_argument("y", "3D Tensor", "Second input.")
+ .add_argument("tensor_a", "3D Tensor", "First input.")
+ .add_argument("tensor_b", "3D Tensor", "Second input.")
.set_support_level(10)
.add_type_rel("BatchMatmul", BatchMatmulRel<BatchMatmulAttrs>);
+// ------------------- relay.nn.batch_matmul
Review comment:
ditto?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]