This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 41879b2 [FIX] Bug fix for batch_matmul parameters mismatch (#8785)
41879b2 is described below
commit 41879b2552364f094492470a77a3ec0866b30eae
Author: Chenfan <[email protected]>
AuthorDate: Thu Aug 19 11:15:34 2021 +0800
[FIX] Bug fix for batch_matmul parameters mismatch (#8785)
---
python/tvm/topi/cuda/batch_matmul.py | 13 ++++++++++++-
python/tvm/topi/cuda/batch_matmul_tensorcore.py | 9 +++++++--
python/tvm/topi/rocm/batch_matmul.py | 7 +++++--
3 files changed, 24 insertions(+), 5 deletions(-)
diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index 3fc8a58..bd556d2 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -237,7 +237,9 @@ def schedule_batch_matmul_cublas(_, outs):
@autotvm.register_topi_compute("batch_matmul_int8.cuda")
-def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
+def batch_matmul_int8(
+    cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
"""Batch Matmul operator for int8 on CUDA.
Parameters
@@ -258,11 +260,20 @@ def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
out_dtype : Optional[str]
Specifies the output data type for mixed precision batch matmul.
+ transpose_a : Optional[bool] = False
+ Whether the first tensor is in transposed format.
+
+ transpose_b : Optional[bool] = True
+ Whether the second tensor is in transposed format.
+
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
+ del out_shape
+ # TODO(jcf94): Deal with different transpose combinations
+ assert not transpose_a and transpose_b
if out_dtype is None:
out_dtype = x.dtype
diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index a56d3c3..5324302 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -29,9 +29,14 @@ from .tensor_intrin import (
@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
-def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None):
+def batch_matmul_tensorcore(
+    cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
"""batch matmul tensorcore operator on cuda"""
- # todo: deal with out_shape for broadcast, liuxin.ai
+ # TODO(jcf94): Deal with different transpose combinations
+ assert not transpose_a and transpose_b
+ # TODO(liuxin.ai): Deal with out_shape for broadcast
+ del out_shape
return batch_matmul_tensorcore_cuda(x, y, out_dtype)
diff --git a/python/tvm/topi/rocm/batch_matmul.py b/python/tvm/topi/rocm/batch_matmul.py
index 7f35f4b..53b51ee 100644
--- a/python/tvm/topi/rocm/batch_matmul.py
+++ b/python/tvm/topi/rocm/batch_matmul.py
@@ -23,7 +23,9 @@ from ..utils import get_const_tuple
@autotvm.register_topi_compute("batch_matmul_rocblas.rocm")
-def batch_matmul_rocblas(cfg, x, y, out_shape=None):
+def batch_matmul_rocblas(
+    cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
"""Computes matrix multiplication of `x` and `y` via rocblas when
`x` and `y` are batched matrices.
@@ -40,12 +42,13 @@ def batch_matmul_rocblas(cfg, x, y, out_shape=None):
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
+ del out_dtype
batch, M, K = get_const_tuple(x.shape)
_, N, _ = get_const_tuple(y.shape)
if out_shape is not None:
assert out_shape[0] == batch, "Input and output batch sizes must match"
assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape"
- result = rocblas.batch_matmul(x, y, False, True)
+ result = rocblas.batch_matmul(x, y, transpose_a, transpose_b)
cfg.add_flop(batch * M * N * K * 2)
return result